diff --git a/R/orsf_vint.R b/R/orsf_vint.R index e3c6bf75..f377dfd3 100644 --- a/R/orsf_vint.R +++ b/R/orsf_vint.R @@ -97,7 +97,17 @@ orsf_vint <- function(object, pd$id_intr <- paste(pd$var_1_name, pd$var_2_name, sep = sep) - pd_split <- split(pd, pd$id_intr) + if(object$tree_type == 'classification'){ + pd[, mean := log(mean+0.01)] + pd[, class := paste0(class, "._aorsf.split_")] + } + + split_vars <- switch(object$tree_type, + "survival" = "id_intr", + "classification" = c("class", "id_intr"), + "regression" = "id_intr") + + pd_split <- split(pd, by = split_vars) # for cran . <- score <- var_1_value <- var_2_value <- NULL @@ -120,6 +130,20 @@ orsf_vint <- function(object, out <- data.table(interaction = names(pd_scores), score = as.numeric(pd_scores)) + if(object$tree_type == 'classification'){ + + out[, class := tstrsplit(interaction, + "\\.\\_aorsf\\.split\\_\\.", + keep = 1L)] + + out[, interaction := tstrsplit(interaction, + "\\.\\_aorsf\\.split\\_\\.", + keep = 2L)] + + out <- out[, .(score = mean(score)), by = c('interaction')] + + } + out[order(-score)] }