Pass additional arguments to customized functions in identifying linear combinations of predictors (suggestion + issue) #67

Closed
AbubakerSuliman opened this issue Sep 30, 2024 · 9 comments

Comments

@AbubakerSuliman

AbubakerSuliman commented Sep 30, 2024

Dear Prof @bcjaeger, thank you so much for such a great package.

First, here are my two cents on improving the speed of method='net'.

In penalized_cph.R, I'd suggest looping only over the indices where fit$df takes a new value instead of over the complete list. Fitting a custom penalized Cox regression this way on pbc_orsf reduces run time by 10% to 30%.

indxs <- c(1, which(diff(fit$df) >= 1) + 1)
for (i in indxs) {
  if (fit$df[i] >= target_df || i == tail(indxs, 1)) {
    return(matrix(fit$beta[, i, drop = TRUE], ncol = 1))
  }
}
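
As a rough illustration of the index trick above, here is a minimal sketch on a made-up df vector (a stand-in for fit$df, not real glmnet output). The helper name pick_index is hypothetical and only mirrors the loop in the suggestion; it is not code from aorsf.

```r
# Toy stand-in for fit$df from a glmnet path (made-up values).
df <- c(0, 1, 1, 2, 2, 2, 4, 5, 5)

# First position, plus every position where df increases.
indxs <- c(1, which(diff(df) >= 1) + 1)

# Hypothetical helper: first visited index whose df reaches target_df,
# falling back to the last visited index if none does.
pick_index <- function(df, indxs, target_df) {
  for (i in indxs) {
    if (df[i] >= target_df || i == tail(indxs, 1)) {
      return(i)
    }
  }
}

picked <- pick_index(df, indxs, target_df = 3)
fallback <- pick_index(df, indxs, target_df = 10)
```

With these toy values the loop visits 5 positions instead of 9, which is where the saving would come from on a real path.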

Second, I'm exploring different methods to create linear combinations of predictors; however, I can't pass an additional argument (e.g. target_df) or access it from the parent environment in the case of a custom function. The following function will throw an error. I would appreciate any ideas on how to solve this issue.

f_net <- function(x_node, y_node, w_node, add_arg) {
  add_arg
  colnames(y_node) <- c('time', 'status')
  colnames(x_node) <- paste("x", seq(ncol(x_node)), sep = '')
  
  data <- as.data.frame(cbind(y_node, x_node))
  
  if(nrow(data) <= 10)
    return(matrix(runif(ncol(x_node)), ncol = 1))
  
  suppressWarnings(
    fit <- try(
      glmnet::glmnet(x = x_node,
                     y = survival::Surv(data$time, data$status),
                     weights = w_node,
                     alpha = 0.5,
                     family = "cox"),
      silent = TRUE
    )
  )
  
  if(aorsf:::is_error(fit)){
    return(matrix(runif(ncol(x_node)), nrow=ncol(x_node), ncol=1))
  }

  indxs = c(1, which(diff(fit$df)>=1)+1)
  for(i in indxs){
    if(fit$df[i] >= 5 || i == tail(indxs, 1)){
      return(matrix(fit$beta[, i, drop=TRUE], ncol = 1))
    }
  }
  
}
@AbubakerSuliman AbubakerSuliman changed the title from "Pass aditional arguments to customized functions in identify linear combinations of predictors (suggesiton + issue)" to "Pass additional arguments to customized functions in identifying linear combinations of predictors (suggestion + issue)" Sep 30, 2024
@bcjaeger
Collaborator

bcjaeger commented Oct 4, 2024

Thank you! If I could get the time, I'd want to write routines in C++ that mimic glmnet so that we wouldn't have to call the R function.

For sending customized input values to the R function in aorsf, why not just define multiple functions that each use different hard coded values of target_df? It's logistically complicated to allow general objects to be passed into these functions b/c we'd have to know what to declare those objects as in C++
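
One possible reading of this suggestion, sketched under assumptions: rather than writing each variant by hand, the hard-coded functions could be generated with bquote() so that target_df ends up as a literal in the function body. The factory name make_f_fixed is hypothetical, the body uses a toy rule instead of the glmnet path, and whether aorsf accepts functions built this way has not been checked here.

```r
# Hypothetical factory (not part of aorsf): returns a function with the
# three-argument signature used in this thread and target_df baked into
# its body as a literal constant.
make_f_fixed <- function(target_df) {
  eval(
    bquote(
      function(x_node, y_node, w_node) {
        # Toy rule standing in for the glmnet path: give non-zero weight
        # to at most target_df predictors.
        n_keep <- min(.(target_df), ncol(x_node))
        beta <- numeric(ncol(x_node))
        beta[seq_len(n_keep)] <- 1
        matrix(beta, ncol = 1)
      }
    ),
    envir = globalenv()
  )
}

f_df3 <- make_f_fixed(3)
f_df5 <- make_f_fixed(5)

# Toy node data: 10 rows, 8 predictors.
x_toy <- matrix(rnorm(80), ncol = 8)
out3 <- f_df3(x_toy, NULL, NULL)
out5 <- f_df5(x_toy, NULL, NULL)
```

Inspecting body(f_df3) shows the substituted constant, so the generated function does not look anything up in an enclosing environment at call time.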

@AbubakerSuliman
Author

AbubakerSuliman commented Oct 10, 2024

Thanks for the prompt response. This is well noted and received.

I have a new issue with using a custom function in mlr3: control_type only accepts the levels declared in p_fct(levels = c("fast", "cph", "net"), default = "fast", tags = "train") (R/learner_aorsf_surv_aorsf.R).
Adding a custom level looks possible from the code, but I suspect there is a technical or design reason it was left out.
I'd appreciate any help or ideas on this issue.

Update 1:
Here is a first try to modify the aorsf learner in mlr3extralearners after quick readings

ps = ps(
...
control_type = p_fct(levels = c("fast", "cph", "net", "custom"), default = "fast", tags = "train"),
control_custom_fun = p_uty(custom_check = function(x) checkmate::checkFunction(x, nargs = 3), 
                                              depends = control_type == "custom", tags = "train"),
...
)
 ...
,
 "custom" = {
   aorsf::orsf_control_survival(
     method = pv$control_custom_fun
   )
 }
)
# these parameters are used to organize the control arguments
# above but are not used directly by aorsf::orsf(), so:
pv = remove_named(pv, c("control_type",
                       ...,
                       "control_custom_fun"))

It seems to be working fine, but I have an issue with importance()

> learner$importance()
x8 x7 x6 x5 x4 x3 x2 x1
 0  0  0  0  0  0  0  0 

@bcjaeger
Collaborator

Here is a first try to modify the aorsf learner in mlr3extralearners after quick readings

This is awesome! Thank you for writing this code. I wonder if we could request the custom method be added to the aorsf learner in mlr3extralearners once we've figured out the importance issue. Do you also get importance values of 0 for this aorsf model when you fit it using aorsf::orsf?

@AbubakerSuliman
Author

I wonder if we could request the custom method be added to the aorsf learner in mlr3extralearners once we've figured out the importance issue

Why not? It would allow benchmarking of anything.

Do you also get importance values of 0 for this aorsf model when you fit it using aorsf::orsf?

Yes, interestingly it works fine with importance = 'negate'/'permute'

@bcjaeger
Collaborator

Thank you! Could you clarify the second item for me? Did aorsf::orsf() give you all 0's for importance values, or did it work fine (i.e., giving non-zero importance values)?

For the PR, would you like to take the lead by initiating an issue on mlr3extralearners? If you'd like, I could do this, but I am slowed down by other obligations and I also want to make sure you get credited for the awesome work you've done.

@AbubakerSuliman
Author

Thank you! Could you clarify the second item for me? Did aorsf::orsf() give you all 0's for importance values, or did it work fine (i.e., giving non-zero importance values)?

aorsf::orsf() with a custom function works fine when I calculate importance using "negate" or "permute"; however, it returns all zeros when importance is "anova". Here is an MWE:

library(aorsf)
f_rando <- function(x_node, y_node, w_node){
  matrix(runif(ncol(x_node)), ncol=1) 
}
fit_rando_anova <- orsf(pbc_orsf,
                        Surv(time, status) ~ . - id,
                        control = orsf_control_survival(method = f_rando),
                        importance = "anova",
                        tree_seeds = 329)
fit_rando_negate <- orsf(pbc_orsf,
                         Surv(time, status) ~ . - id,
                         control = orsf_control_survival(method = f_rando),
                         importance = "negate",
                         tree_seeds = 329)
fit_rando_permute <- orsf(pbc_orsf,
                          Surv(time, status) ~ . - id,
                          control = orsf_control_survival(method = f_rando),
                          importance = "permute",
                          tree_seeds = 329)

fit_rando_anova$importance
   stage  protime platelet     trig      ast alk.phos   copper  albumin     chol     bili 
       0        0        0        0        0        0        0        0        0        0 
   edema  spiders   hepato  ascites      sex      age      trt 
       0        0        0        0        0        0        0 

fit_rando_negate$importance
        bili       copper      protime        stage          age          ast          sex 
 0.061821928  0.061275078  0.045544739  0.039651143  0.039253153  0.027136941  0.019333503 
      hepato         chol      spiders         trig     alk.phos      ascites      albumin 
 0.015253945  0.014237792  0.010395895  0.010026580  0.009735672  0.009062175  0.007236837 
       edema          trt     platelet 
 0.005684531  0.003000386 -0.002974909 

fit_rando_permute$importance
       copper          bili       protime           age         stage           ast          chol 
 0.0288431287  0.0275947301  0.0235931010  0.0216206097  0.0174380106  0.0136475128  0.0076317414 
       hepato       spiders       albumin          trig       ascites         edema      alk.phos 
 0.0060957151  0.0060563554  0.0052816286  0.0050055876  0.0049497113  0.0028515817  0.0019844534 
          sex           trt      platelet 
-0.0005690087 -0.0020921897 -0.0030092387 

Regarding the PR, many thanks for the kind words. Sure, I will start the PR soon.

@bcjaeger
Collaborator

Ahh, I see, that makes sense. ANOVA importance requires calculation of p-values so I didn't even attempt to do anova importance when a custom function is used to get linear combinations of predictors. I think perhaps aorsf::orsf() should throw an error if someone uses a custom function with ANOVA importance to prevent this confusing result - do you think that would be helpful?
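
A minimal sketch of what such a check might look like, written as a standalone guard rather than as aorsf's actual implementation. The function name check_importance and the error wording are assumptions made for illustration only.

```r
# Hypothetical guard, not part of aorsf: rejects ANOVA importance when the
# linear-combination method is a user-supplied function.
check_importance <- function(method, importance) {
  if (is.function(method) && identical(importance, "anova")) {
    stop(
      "importance = 'anova' is not available with a custom function; ",
      "use 'negate' or 'permute' instead.",
      call. = FALSE
    )
  }
  invisible(TRUE)
}

# Same custom function as in the MWE earlier in this thread.
f_rando <- function(x_node, y_node, w_node) {
  matrix(runif(ncol(x_node)), ncol = 1)
}

# Passes silently for negation importance.
check_importance(f_rando, "negate")

# Captures the error message for the ANOVA case instead of halting.
msg <- tryCatch(
  check_importance(f_rando, "anova"),
  error = function(e) conditionMessage(e)
)
```

Failing early like this would replace the silent all-zero importance vector with an explicit message.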

@AbubakerSuliman
Author

AbubakerSuliman commented Oct 17, 2024

do you think that would be helpful?

Of course. I wondered why you don't allow ANOVA for custom methods, and then I read the following on the aorsf GitHub main page: "ANOVA is very efficient computationally, but may not be as effective as permutation or negation in terms of selecting signal over noise variables."
So yes, an error message would be enough.

I opened a PR for the custom method here.

Finally, feel free to close this issue.

@bcjaeger
Collaborator

Thank you! I should update that main page to also mention that we can only compute anova importance if the linear combination method allows us to compute p-values for the variables that are being combined.
