From 9ebf6f24197c0ce582f5e77fe148b1342bd92b60 Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Wed, 8 Nov 2023 21:47:59 -0500 Subject: [PATCH] working but not tested --- DESCRIPTION | 2 +- R/data-penguins_orsf.R | 5 +- R/orsf_R6.R | 186 +++++++++++++++++++++++------ R/orsf_summary.R | 136 +-------------------- man/penguins_orsf.Rd | 5 +- src/Forest.cpp | 8 +- src/TreeClassification.cpp | 5 +- src/utility.cpp | 14 +++ src/utility.h | 2 + tests/testthat/test-orsf_pd.R | 3 - tests/testthat/test-orsf_predict.R | 78 +----------- tests/testthat/test-orsf_summary.R | 3 - tests/testthat/test-unit_info.R | 139 +++++++++++++-------- 13 files changed, 276 insertions(+), 310 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 9c25cfbf..2dfebbcd 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: aorsf Title: Accelerated Oblique Random Survival Forests -Version: 0.1.1.9002 +Version: 0.1.1.9003 Authors@R: c( person(given = "Byron", family = "Jaeger", diff --git a/R/data-penguins_orsf.R b/R/data-penguins_orsf.R index bb88a9f6..f5c22004 100644 --- a/R/data-penguins_orsf.R +++ b/R/data-penguins_orsf.R @@ -1,7 +1,8 @@ #' Size measurements for adult foraging penguins near Palmer Station, Antarctica #' -#' These data are copied and lightly modified from the `palmerpenguins` -#' `penguins` data. The only modification is removal of rows +#' These data are copied and lightly modified from the `penguins` data in +#' the [palmerpenguins](https://allisonhorst.github.io/palmerpenguins/) R +#' package. The only modification is removal of rows #' with missing data. The data include measurements for penguin species, #' island in Palmer Archipelago, size (flipper length, body mass, bill #' dimensions), and sex. 
diff --git a/R/orsf_R6.R b/R/orsf_R6.R index e23097ea..1e98a03c 100644 --- a/R/orsf_R6.R +++ b/R/orsf_R6.R @@ -275,7 +275,7 @@ ObliqueForest <- R6::R6Class( private$check_pred_aggregate(pred_aggregate) if(self$tree_type == 'survival') - private$check_pred_horizon(boundary_checks, pred_horizon) + private$check_pred_horizon(pred_horizon, boundary_checks) self$data <- new_data self$pred_horizon <- pred_horizon @@ -319,12 +319,6 @@ ObliqueForest <- R6::R6Class( }, - predict_internal = function(){ - - stop("this method should only be called from derived classes") - - }, - compute_vi = function(type_vi, oobag_fun, n_thread, @@ -383,8 +377,8 @@ ObliqueForest <- R6::R6Class( oobag, type_output){ - public_state <- list(data = self$data, - na_action = self$na_action) + public_state <- list(data = self$data, + na_action = self$na_action) private_state <- list(data_rows_complete = private$data_rows_complete) @@ -556,32 +550,6 @@ ObliqueForest <- R6::R6Class( pd_vals <- results - # denominator issue in this (i think) - # cpp_args = private$prep_cpp_args(x = private$x, - # y = private$y, - # w = private$w, - # importance_type = 'none', - # pred_type = pred_type, - # pred_aggregate = TRUE, - # pred_horizon = pred_horizon_ordered, - # oobag = oobag, - # oobag_eval_type = 'none', - # pred_mode = FALSE, - # pred_aggregate = TRUE, - # write_forest = FALSE, - # run_forest = TRUE, - # pd_type_R = switch(type_output, - # "smry" = 1L, - # "ice" = 2L), - # pd_x_vals = pred_spec_new, - # pd_x_cols = x_cols, - # pd_probs = prob_values, - # verbosity = 0) - # - # orsf_out <- do.call(orsf_cpp, args = cpp_args) - - # pd_vals <- orsf_out$pd_values - for(i in seq_along(pd_vals)){ pd_bind[[i]]$id_variable <- seq(nrow(pd_bind[[i]])) @@ -769,6 +737,116 @@ ObliqueForest <- R6::R6Class( }, + summarize_uni = function(n_variables = NULL, + pred_horizon = NULL, + pred_type = NULL, + importance_type = NULL){ + + # check incoming values if they were specified. 
+ private$check_n_variables(n_variables) + private$check_pred_horizon(pred_horizon, boundary_checks = TRUE) + private$check_pred_type(pred_type, oobag = FALSE) + private$check_importance_type(importance_type) + + names_x <- private$data_names$x_original + + # use existing values if incoming ones were not specified + n_variables <- n_variables %||% length(names_x) + pred_horizon <- pred_horizon %||% self$pred_horizon + pred_type <- pred_type %||% self$pred_type + importance_type <- importance_type %||% self$importance_type + + # bindings for CRAN check + value <- NULL + level <- NULL + + # TODO: make this go away. Just sort alphabetically if no importance + if(importance_type == 'none' && is_empty(self$importance_type)) + stop("importance cannot be 'none' if object does not have variable", + " importance values.", call. = FALSE) + + vi <- switch( + importance_type, + 'anova' = orsf_vi_anova(self, group_factors = TRUE), + 'negate' = orsf_vi_negate(self, group_factors = TRUE), + 'permute' = orsf_vi_permute(self, group_factors = TRUE), + 'none' = NULL + ) + + bounds <- private$data_bounds + fctrs <- private$data_fctrs + n_obs <- self$n_obs + + names_vi <- names(vi) %||% names_x + + pred_spec <- list_init(names_vi)[seq(n_variables)] + + for(i in names(pred_spec)){ + + if(i %in% colnames(bounds)){ + + pred_spec[[i]] <- unique( + as.numeric(bounds[c('25%','50%','75%'), i]) + ) + + } else if (i %in% fctrs$cols) { + + pred_spec[[i]] <- fctrs$lvls[[i]] + + } + + } + + pd_output <- orsf_pd_oob(object = self, + pred_spec = pred_spec, + expand_grid = FALSE, + pred_type = pred_type, + prob_values = c(0.25, 0.50, 0.75), + pred_horizon = pred_horizon) + + fctrs_unordered <- c() + + # did the orsf have factor variables? + if(!is_empty(fctrs$cols)){ + fctrs_unordered <- fctrs$cols[!fctrs$ordr] + } + + # some cart-wheels here for backward compatibility. 
+ f <- as.factor(pd_output$variable) + + name_rep <- rle(as.integer(f)) + + pd_output$importance <- rep(vi[levels(f)[name_rep$values]], + times = name_rep$lengths) + + pd_output[, value := fifelse(test = is.na(value), + yes = as.character(level), + no = round_magnitude(value))] + + # if a := is used inside a function with no DT[] before the end of the + # function, then the next time DT or print(DT) is typed at the prompt, + # nothing will be printed. A repeated DT or print(DT) will print. + # To avoid this: include a DT[] after the last := in your function. + pd_output[] + + setcolorder(pd_output, c('variable', + 'importance', + 'value', + 'mean', + 'medn', + 'lwr', + 'upr')) + + structure( + .Data = list(dt = pd_output, + pred_type = pred_type, + pred_horizon = pred_horizon), + class = 'orsf_summary_uni' + ) + + + }, + # getters get_names_x = function(ref_coded = FALSE){ @@ -1387,6 +1465,39 @@ ObliqueForest <- R6::R6Class( expected_length = 1) }, + + check_n_variables = function(n_variables = NULL){ + + # n_variables is not a field of ObliqueForest, + # so it is only checked as an incoming input. 
+ + if(!is.null(n_variables)){ + + check_arg_type(arg_value = n_variables, + arg_name = 'n_variables', + expected_type = 'numeric') + + check_arg_is_integer(arg_value = n_variables, + arg_name = 'n_variables') + + check_arg_gteq(arg_value = n_variables, + arg_name = 'n_variables', + bound = 1) + + check_arg_lteq(arg_value = n_variables, + arg_name = 'n_variables', + bound = length(private$data_names$x_original), + append_to_msg = "(total number of predictors)") + + + check_arg_length(arg_value = n_variables, + arg_name = 'n_variables', + expected_length = 1) + + } + + }, + check_mtry = function(mtry = NULL){ input <- mtry %||% self$mtry @@ -2352,7 +2463,7 @@ ObliqueForestSurvival <- R6::R6Class( }, - check_pred_horizon = function(boundary_checks = TRUE, pred_horizon = NULL){ + check_pred_horizon = function(pred_horizon = NULL, boundary_checks = TRUE){ input <- pred_horizon %||% self$pred_horizon @@ -2451,6 +2562,7 @@ ObliqueForestSurvival <- R6::R6Class( self$tree_type <- "survival" + self$split_rule <- self$split_rule %||% 'logrank' self$pred_type <- self$pred_type %||% 'surv' self$split_min_stat <- self$split_min_stat %||% @@ -2501,7 +2613,7 @@ ObliqueForestSurvival <- R6::R6Class( if(is.null(self$pred_horizon)){ self$pred_horizon <- collapse::fmedian(y[, 1]) } else { - private$check_pred_horizon(boundary_checks = TRUE) + private$check_pred_horizon(self$pred_horizon, boundary_checks = TRUE) } private$check_leaf_min_events() diff --git a/R/orsf_summary.R b/R/orsf_summary.R index d8bd5eae..c867cc64 100644 --- a/R/orsf_summary.R +++ b/R/orsf_summary.R @@ -57,144 +57,20 @@ orsf_summarize_uni <- function(object, n_variables = NULL, pred_horizon = NULL, - pred_type = 'risk', - importance = 'negate', + pred_type = NULL, + importance = NULL, ...){ - # bindings for CRAN check - value <- NULL - level <- NULL - check_dots(list(...), .f = orsf_summarize_uni) check_arg_is(arg_value = object, arg_name = 'object', expected_class = 'ObliqueForest') - if(!is.null(n_variables)){ - 
- check_arg_type(arg_value = n_variables, - arg_name = 'n_variables', - expected_type = 'numeric') - - check_arg_is_integer(arg_value = n_variables, - arg_name = 'n_variables') - - check_arg_gteq(arg_value = n_variables, - arg_name = 'n_variables', - bound = 1) - - check_arg_lteq(arg_value = n_variables, - arg_name = 'n_variables', - bound = length(object$get_names_x()), - append_to_msg = "(total number of predictors)") - - - check_arg_length(arg_value = n_variables, - arg_name = 'n_variables', - expected_length = 1) - - } - - check_predict(object = object, - pred_horizon = pred_horizon, - pred_type = pred_type) - - if(is.null(pred_horizon)) pred_horizon <- object$pred_horizon - - if(importance == 'none' && is_empty(object$importance)) - stop("importance cannot be 'none' if object does not have variable", - " importance values.", call. = FALSE) - - check_orsf_inputs(importance = importance) - - if(importance == 'none') importance <- object$importance_type - - vi <- switch( - importance, - 'anova' = orsf_vi_anova(object, group_factors = TRUE), - 'negate' = orsf_vi_negate(object, group_factors = TRUE), - 'permute' = orsf_vi_permute(object, group_factors = TRUE) - ) - - if(is.null(n_variables)) n_variables <- length(vi) - - - x_numeric_key <- object$get_bounds() - - fctr_info <- object$get_fctr_info() - - n_obs <- object$n_obs - - pred_spec <- list_init(names(vi)[seq(n_variables)]) - - for(x_name in names(pred_spec)){ - - if(x_name %in% colnames(x_numeric_key)){ - - pred_spec[[x_name]] <- unique( - as.numeric(x_numeric_key[c('25%','50%','75%'), x_name]) - ) - - } else if (x_name %in% fctr_info$cols) { - - pred_spec[[x_name]] <- fctr_info$lvls[[x_name]] - - } - - } - - pd_output <- orsf_pd_oob(object = object, - pred_spec = pred_spec, - expand_grid = FALSE, - pred_type = pred_type, - prob_values = c(0.25, 0.50, 0.75), - pred_horizon = pred_horizon) - - fctrs_unordered <- c() - - # did the orsf have factor variables? 
- if(!is_empty(fctr_info$cols)){ - fctrs_unordered <- fctr_info$cols[!fctr_info$ordr] - } - - # some cart-wheels here for backward compatibility. - f <- as.factor(pd_output$variable) - - name_rep <- rle(as.integer(f)) - - pd_output$importance <- rep(vi[levels(f)[name_rep$values]], - times = name_rep$lengths) - - # pd_output$value <- ifelse(test = is.na(value), - # yes = as.character(level), - # no = round_magnitude(value)) - - pd_output[, value := fifelse(test = is.na(value), - yes = as.character(level), - no = round_magnitude(value))] - - # if a := is used inside a function with no DT[] before the end of the - # function, then the next time DT or print(DT) is typed at the prompt, - # nothing will be printed. A repeated DT or print(DT) will print. - # To avoid this: include a DT[] after the last := in your function. - pd_output[] - - setcolorder(pd_output, c('variable', - 'importance', - 'value', - 'mean', - 'medn', - 'lwr', - 'upr')) - - structure( - .Data = list(dt = pd_output, - pred_type = pred_type, - pred_horizon = pred_horizon), - class = 'orsf_summary_uni' - ) - + object$summarize_uni(n_variables = n_variables, + pred_horizon = pred_horizon, + pred_type = pred_type, + importance_type = importance) } diff --git a/man/penguins_orsf.Rd b/man/penguins_orsf.Rd index 1dcf9d0c..c7b84241 100644 --- a/man/penguins_orsf.Rd +++ b/man/penguins_orsf.Rd @@ -30,8 +30,9 @@ A tibble with 333 rows and 8 variables: penguins_orsf } \description{ -These data are copied and lightly modified from the \code{palmerpenguins} -\code{penguins} data. The only modification is removal of rows +These data are copied and lightly modified from the \code{penguins} data in +the \href{https://allisonhorst.github.io/palmerpenguins/}{palmerpenguins} R +package. The only modification is removal of rows with missing data. The data include measurements for penguin species, island in Palmer Archipelago, size (flipper length, body mass, bill dimensions), and sex. 
diff --git a/src/Forest.cpp b/src/Forest.cpp index 1dfb7a40..c1185728 100644 --- a/src/Forest.cpp +++ b/src/Forest.cpp @@ -603,7 +603,7 @@ mat Forest::predict(bool oobag) { aborted_threads = 0; if(n_thread == 1){ - // ensure safe usage of R functions + predict_single_thread(data.get(), oobag, result); } else { @@ -677,7 +677,7 @@ mat Forest::predict(bool oobag) { } // it's okay if we divide by 0 here. It makes the result NaN but - // that will be fixed when the results are post-processed in R/orsf.R + // that will be fixed when the results are cleaned in R result.each_col() /= oobag_denom; } else { @@ -686,6 +686,10 @@ mat Forest::predict(bool oobag) { } + if(pred_type == PRED_CLASS){ + predict_class(result); + } + return(result); } diff --git a/src/TreeClassification.cpp b/src/TreeClassification.cpp index c1912a79..3fe5644f 100644 --- a/src/TreeClassification.cpp +++ b/src/TreeClassification.cpp @@ -122,7 +122,10 @@ uword leaf_id = pred_leaf[it]; if(leaf_id == max_nodes) break; - pred_output.row(it) = leaf_summary[leaf_id]; + + pred_output.at(it, leaf_summary[leaf_id])++; + + // pred_output.row(it) = leaf_summary[leaf_id]; n_preds_made++; if(oobag) pred_denom[it]++; diff --git a/src/utility.cpp b/src/utility.cpp index dff5a42c..d46d53be 100644 --- a/src/utility.cpp +++ b/src/utility.cpp @@ -672,5 +672,19 @@ } + void predict_class(arma::mat& pred){ + + // modify column 0 + for(uword i = 0; i < pred.n_rows; ++i){ + pred.at(i, 0) = pred.row(i).index_max(); + } + + // drop the other columns + while(pred.n_cols > 1){ + pred.shed_col(1); + } + + } + } diff --git a/src/utility.h b/src/utility.h index 95982efd..c0fe330b 100644 --- a/src/utility.h +++ b/src/utility.h @@ -114,6 +114,8 @@ aorsf may be modified and distributed under the terms of the MIT license. 
arma::vec& beta_var, arma::mat& x_transforms); + void predict_class(arma::mat& pred); + } #endif /* UTILITY_H */ diff --git a/tests/testthat/test-orsf_pd.R b/tests/testthat/test-orsf_pd.R index 71272327..3551ddfa 100644 --- a/tests/testthat/test-orsf_pd.R +++ b/tests/testthat/test-orsf_pd.R @@ -57,9 +57,6 @@ test_that( ) funs <- list( - # ice_new = orsf_ice_new, - # ice_inb = orsf_ice_inb, - # ice_oob = orsf_ice_oob, pd_new = orsf_pd_new, pd_inb = orsf_pd_inb, pd_oob = orsf_pd_oob diff --git a/tests/testthat/test-orsf_predict.R b/tests/testthat/test-orsf_predict.R index 35ac3ede..1b653c6a 100644 --- a/tests/testthat/test-orsf_predict.R +++ b/tests/testthat/test-orsf_predict.R @@ -1,4 +1,6 @@ +pred_horizon <- c(1000, 2500) + test_preds_surv <- function(pred_type){ n_train <- nrow(pbc_train) @@ -129,8 +131,6 @@ test_preds_surv <- function(pred_type){ } -pred_horizon <- c(1000, 2500) - pred_objects_surv <- lapply(pred_types_surv, test_preds_surv) test_that( @@ -557,27 +557,6 @@ test_that( ) -# Just run locally. Possible memory leaks from units. -# test_that( -# desc = 'inconsistent units are detected', -# code = { -# -# suppressMessages(library(units)) -# pbc_units <- pbc_orsf -# units(pbc_units$age) <- 'years' -# -# pbc_test_units <- pbc_test -# units(pbc_test_units$age) <- 'days' -# -# fit <- orsf(formula = time + status ~ . - id, -# data = pbc_units, -# n_tree = n_tree_test) -# -# expect_error(predict(fit, new_data = pbc_test_units, pred_horizon = 1000), -# 'has unit') -# -# } -# ) test_that( desc = 'predictions dont require cols in same order as training data', @@ -598,59 +577,6 @@ test_that( ) -# test_that( -# 'units are vetted in testing data', -# code = { -# -# suppressMessages(library(units)) -# pbc_units_trn <- pbc_train -# pbc_units_tst <- pbc_test -# -# -# units(pbc_units_trn$time) <- 'days' -# units(pbc_units_trn$age) <- 'years' -# units(pbc_units_trn$bili) <- 'mg/dl' -# -# fit_units = orsf(formula = time + status ~ . 
- id, -# data = pbc_units_trn, -# n_tree = n_tree_test, -# oobag_pred_horizon = c(1000, 2500), -# tree_seeds = seeds_standard) -# -# expect_error( -# predict(fit_units, new_data = pbc_units_tst, pred_horizon = 1000), -# regexp = 'time, age, and bili' -# ) -# -# units(pbc_units_tst$time) <- 'years' -# units(pbc_units_tst$age) <- 'years' -# units(pbc_units_tst$bili) <- 'mg/dl' -# -# expect_error( -# predict(fit_units, new_data = pbc_units_tst, pred_horizon = 1000), -# regexp = 'time has unit d in the training data' -# ) -# -# units(pbc_units_tst$time) <- 'days' -# units(pbc_units_tst$age) <- 'years' -# units(pbc_units_tst$bili) <- 'mg/dl' -# -# expect_equal_leaf_summary(fit_units, pred_objects_surv$surv$fit) -# expect_equal_oobag_eval(fit_units, pred_objects_surv$surv$fit) -# -# units(pbc_units_tst$time) <- 'days' -# units(pbc_units_tst$age) <- 'years' -# units(pbc_units_tst$bili) <- 'mg/l' -# -# expect_error( -# predict(fit_units, new_data = pbc_units_tst, pred_horizon = 1000), -# regexp = 'bili has unit mg/dl in the training data' -# ) -# -# } -# -# ) - # Tests for passing missing data ---- na_index_age <- c(1, 4, 8) diff --git a/tests/testthat/test-orsf_summary.R b/tests/testthat/test-orsf_summary.R index 9e4e6ecd..90714ad0 100644 --- a/tests/testthat/test-orsf_summary.R +++ b/tests/testthat/test-orsf_summary.R @@ -52,9 +52,6 @@ no_miss_list <- function(l){ fi <- object$get_fctr_info() -#' @srrstats {G5.2} *Appropriate error behaviour is explicitly demonstrated through tests.* -#' @srrstats {G5.2b} *Tests demonstrate conditions which trigger error messages.* - test_that("output is normal", { diff --git a/tests/testthat/test-unit_info.R b/tests/testthat/test-unit_info.R index 1abfbe04..51429bf3 100644 --- a/tests/testthat/test-unit_info.R +++ b/tests/testthat/test-unit_info.R @@ -1,53 +1,86 @@ -# -# -# # handle 'units' variables -# suppressMessages(library(units)) -# -# pbc_units <- pbc_orsf -# -# units(pbc_units$time) <- 'days' -# units(pbc_units$age) <- 'years' 
-# units(pbc_units$bili) <- 'mg/dl' -# -# test_that("output has expected items", { -# -# ui <- unit_info(pbc_units, c('time', 'age', 'bili')) -# -# expect_equal( -# ui, -# list( -# time = list( -# numerator = "d", -# denominator = character(0), -# label = "d" -# ), -# age = list( -# numerator = "years", -# denominator = character(0), -# label = "years" -# ), -# bili = list( -# numerator = "mg", -# denominator = "dl", -# label = "mg/dl" -# ) -# ) -# ) -# -# expect_true(is_empty(unit_info(pbc_units, c()))) -# -# }) -# -# pbc_units_badclass <- pbc_units -# class(attr(pbc_units_badclass$bili, 'units')) <- 'bad_units' -# -# -# test_that('only symbolic units are allowed', { -# -# expect_error(unit_info(pbc_units_badclass, 'bili'), 'symbolic_units') -# -# }) -# -# -# -# + + +# handle 'units' variables. All of these tests are skipped +# on CRAN because for some reason when I load and use the +# units package it makes valgrind detect possible memory leaks. + +test_that("output has expected items", { + + skip_on_cran() + + suppressMessages(library(units)) + + pbc_units <- pbc_orsf + + units(pbc_units$time) <- 'days' + units(pbc_units$age) <- 'years' + units(pbc_units$bili) <- 'mg/dl' + ui <- unit_info(pbc_units, c('time', 'age', 'bili')) + + expect_equal( + ui, + list( + time = list( + numerator = "d", + denominator = character(0), + label = "d" + ), + age = list( + numerator = "years", + denominator = character(0), + label = "years" + ), + bili = list( + numerator = "mg", + denominator = "dl", + label = "mg/dl" + ) + ) + ) + + expect_true(is_empty(unit_info(pbc_units, c()))) + +}) + + +test_that('only symbolic units are allowed', { + + + skip_on_cran() + + suppressMessages(library(units)) + + pbc_units <- pbc_orsf + + units(pbc_units$bili) <- 'mg/dl' + + class(attr(pbc_units$bili, 'units')) <- 'bad_units' + + expect_error(unit_info(pbc_units, 'bili'), 'symbolic_units') + +}) + + +test_that( + desc = 'inconsistent units are detected', + code = { + + skip_on_cran() + + 
suppressMessages(library(units)) + + pbc_units <- pbc_orsf + units(pbc_units$age) <- 'years' + + pbc_test_units <- pbc_orsf + units(pbc_test_units$age) <- 'days' + + fit <- orsf(formula = time + status ~ . - id, + data = pbc_units, + n_tree = n_tree_test) + + expect_error(predict(fit, new_data = pbc_test_units, pred_horizon = 1000), + 'has unit') + + } +)