Add code examples for wrapper functions

nanxstats committed Apr 27, 2024
1 parent 9ffae2b commit ce0e728
Showing 6 changed files with 144 additions and 17 deletions.
23 changes: 12 additions & 11 deletions R/stackgbm.R
@@ -35,7 +35,7 @@ stackgbm <- function(x, y, params, n_folds = 5L, seed = 42, verbose = TRUE) {
   x_glm <- matrix(NA, nrow = nrow_x, ncol = 3L)
   colnames(x_glm) <- c("xgb", "lgb", "cat")

-  # xgboost
+  # xgboost ----
   pb <- progress_bar$new(
     format = " fitting xgboost model [:bar] :percent in :elapsed",
     total = n_folds, clear = FALSE, width = 60
@@ -68,7 +68,7 @@ stackgbm <- function(x, y, params, n_folds = 5L, seed = 42, verbose = TRUE) {
     x_glm[index_xgb == i, "xgb"] <- predict(fit, xtest)
   }

-  # lightgbm
+  # lightgbm ----
   pb <- progress_bar$new(
     format = " fitting lightgbm model [:bar] :percent in :elapsed",
     total = n_folds, clear = FALSE, width = 60
@@ -100,7 +100,7 @@ stackgbm <- function(x, y, params, n_folds = 5L, seed = 42, verbose = TRUE) {
     x_glm[index_lgb == i, "lgb"] <- predict(fit, xtest)
   }

-  # catboost
+  # catboost ----
   pb <- progress_bar$new(
     format = " fitting catboost model [:bar] :percent in :elapsed",
     total = n_folds, clear = FALSE, width = 60
@@ -130,17 +130,18 @@ stackgbm <- function(x, y, params, n_folds = 5L, seed = 42, verbose = TRUE) {
     x_glm[index_cat == i, "cat"] <- catboost_predict(fit, pool = test_pool, prediction_type = "Probability")
   }

-  # logistic regression
+  # Logistic regression ----
   df <- as.data.frame(cbind(y, x_glm))
   names(df)[1] <- "y"
   model_glm <- glm(y ~ ., data = df, family = binomial())

-  lst <- list(
-    "model_xgb" = model_xgb,
-    "model_lgb" = model_lgb,
-    "model_cat" = model_cat,
-    "model_glm" = model_glm
+  structure(
+    list(
+      "model_xgb" = model_xgb,
+      "model_lgb" = model_lgb,
+      "model_cat" = model_cat,
+      "model_glm" = model_glm
+    ),
+    class = "stackgbm"
   )
-  class(lst) <- "stackgbm"
-  lst
 }
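
The return-value change in the last hunk replaces a two-step class assignment with a single structure() call. A minimal standalone sketch of the same idiom, using illustrative object names rather than the package's fitted models:

# Before: build the list, then attach the class in a second step.
res <- list(model_a = 1, model_b = 2)
class(res) <- "stackgbm"

# After: construct the classed object in one expression.
res <- structure(
  list(model_a = 1, model_b = 2),
  class = "stackgbm"
)
inherits(res, "stackgbm")  # TRUE
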
28 changes: 27 additions & 1 deletion R/wrappers_lightgbm.R
@@ -10,7 +10,33 @@
 #' @export
 #'
 #' @examplesIf is_installed_lightgbm()
-#' # Example code
+#' sim_data <- msaenet::msaenet.sim.binomial(
+#'   n = 100,
+#'   p = 10,
+#'   rho = 0.6,
+#'   coef = rnorm(5, mean = 0, sd = 10),
+#'   snr = 1,
+#'   p.train = 0.8,
+#'   seed = 42
+#' )
+#'
+#' fit <- suppressWarnings(
+#'   lightgbm_train(
+#'     data = sim_data$x.tr,
+#'     label = sim_data$y.tr,
+#'     params = list(
+#'       objective = "binary",
+#'       learning_rate = 0.1,
+#'       num_iterations = 100,
+#'       max_depth = 3,
+#'       num_leaves = 2^3 - 1,
+#'       num_threads = 1
+#'     ),
+#'     verbose = -1
+#'   )
+#' )
+#'
+#' fit
 lightgbm_train <- function(data, label, params, ...) {
   rlang::check_installed("lightgbm", reason = "to train the model")
   cl <- rlang::call2(
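
A natural follow-up to the new lightgbm_train() example is scoring the simulated test split. This sketch is not part of the commit; it assumes the returned booster accepts a plain matrix through lightgbm's standard predict() method:

# Hypothetical continuation of the roxygen example above.
pred <- predict(fit, sim_data$x.te)  # class probabilities for the held-out rows
head(pred)
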
41 changes: 39 additions & 2 deletions R/wrappers_xgboost.R
@@ -9,7 +9,20 @@
 #' @export
 #'
 #' @examplesIf is_installed_xgboost()
-#' # Example code
+#' sim_data <- msaenet::msaenet.sim.binomial(
+#'   n = 100,
+#'   p = 10,
+#'   rho = 0.6,
+#'   coef = rnorm(5, mean = 0, sd = 10),
+#'   snr = 1,
+#'   p.train = 0.8,
+#'   seed = 42
+#' )
+#'
+#' x_train <- xgboost_dmatrix(sim_data$x.tr, label = sim_data$y.tr)
+#' x_train
+#' x_test <- xgboost_dmatrix(sim_data$x.te)
+#' x_test
 xgboost_dmatrix <- function(data, label = NULL, ...) {
   rlang::check_installed("xgboost", reason = "to create a dataset")
   cl <- if (is.null(label)) {
@@ -32,7 +45,31 @@ xgboost_dmatrix <- function(data, label = NULL, ...) {
 #' @export
 #'
 #' @examplesIf is_installed_xgboost()
-#' # Example code
+#' sim_data <- msaenet::msaenet.sim.binomial(
+#'   n = 100,
+#'   p = 10,
+#'   rho = 0.6,
+#'   coef = rnorm(5, mean = 0, sd = 10),
+#'   snr = 1,
+#'   p.train = 0.8,
+#'   seed = 42
+#' )
+#'
+#' x_train <- xgboost_dmatrix(sim_data$x.tr, label = sim_data$y.tr)
+#'
+#' fit <- xgboost_train(
+#'   params = list(
+#'     objective = "binary:logistic",
+#'     eval_metric = "auc",
+#'     max_depth = 3,
+#'     eta = 0.1
+#'   ),
+#'   data = x_train,
+#'   nrounds = 100,
+#'   nthread = 1
+#' )
+#'
+#' fit
 xgboost_train <- function(params, data, nrounds, ...) {
   rlang::check_installed("xgboost", reason = "to train the model")
   cl <- rlang::call2(
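
Likewise, the x_test object built in the xgboost_dmatrix() example can be scored with the fitted booster. This sketch is not part of the commit and assumes xgboost's standard predict() method for boosters applied to a DMatrix:

# Hypothetical continuation of the roxygen examples above.
pred <- predict(fit, x_test)  # probabilities under the binary:logistic objective
head(pred)
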
28 changes: 27 additions & 1 deletion man/lightgbm_train.Rd


15 changes: 14 additions & 1 deletion man/xgboost_dmatrix.Rd


26 changes: 25 additions & 1 deletion man/xgboost_train.Rd

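The three man/*.Rd files change only because they are generated from the roxygen comments edited above; regenerating them is typically a single documentation step (the exact workflow shown here is an assumption, not part of the commit):

# Rebuild the .Rd files from the roxygen blocks, e.g. via devtools:
devtools::document()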
