From d209842ccbb2a8d8d1c762ebaa591087c8cbd7d9 Mon Sep 17 00:00:00 2001 From: byron jaeger Date: Sun, 24 Dec 2023 13:47:01 -0500 Subject: [PATCH] better score for classif --- R/orsf_vint.R | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/R/orsf_vint.R b/R/orsf_vint.R index e3c6bf75..f377dfd3 100644 --- a/R/orsf_vint.R +++ b/R/orsf_vint.R @@ -97,7 +97,17 @@ orsf_vint <- function(object, pd$id_intr <- paste(pd$var_1_name, pd$var_2_name, sep = sep) - pd_split <- split(pd, pd$id_intr) + if(object$tree_type == 'classification'){ + pd[, mean := log(mean+0.01)] + pd[, class := paste0(class, "._aorsf.split_")] + } + + split_vars <- switch(object$tree_type, + "survival" = "id_intr", + "classification" = c("class", "id_intr"), + "regression" = "id_intr") + + pd_split <- split(pd, by = split_vars) # for cran . <- score <- var_1_value <- var_2_value <- NULL @@ -120,6 +130,20 @@ orsf_vint <- function(object, out <- data.table(interaction = names(pd_scores), score = as.numeric(pd_scores)) + if(object$tree_type == 'classification'){ + + out[, class := tstrsplit(interaction, + "\\.\\_aorsf\\.split\\_\\.", + keep = 1L)] + + out[, interaction := tstrsplit(interaction, + "\\.\\_aorsf\\.split\\_\\.", + keep = 2L)] + + out <- out[, .(score = mean(score)), by = c('interaction')] + + } + out[order(-score)] }