From eaca11811221066905b2c66996699cbacd0d42dc Mon Sep 17 00:00:00 2001
From: Hannah Frick
Date: Thu, 27 Jul 2023 13:20:41 +0100
Subject: [PATCH] Support for 3-way validation split interface (#701)

* support `initial_validation_split` objects in `last_fit()`
* add `use_validation_set` arg
* rename arg to `add_validation_set`
* move out of main function for readability
* fix namespace
* update remote to main branch
---
 DESCRIPTION                    |  2 +-
 NEWS.md                        |  2 ++
 R/fit_best.R                   | 31 +++++++++++++++-
 R/last_fit.R                   | 51 +++++++++++++++++++++++---
 man/fit_best.Rd                | 16 ++++++++-
 man/last_fit.Rd                | 14 ++++++--
 tests/testthat/test-fit_best.R | 65 ++++++++++++++++++++++++++++++++++
 tests/testthat/test-last-fit.R | 65 ++++++++++++++++++++++++++++++++++
 8 files changed, 235 insertions(+), 11 deletions(-)

diff --git a/DESCRIPTION b/DESCRIPTION
index 10b719f2c..164179589 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -31,7 +31,7 @@ Imports:
     purrr (>= 1.0.0),
     recipes (>= 1.0.4),
     rlang (>= 1.1.0),
-    rsample (>= 1.0.0),
+    rsample (>= 1.1.1.9001),
     tibble (>= 3.1.0),
     tidyr (>= 1.2.0),
     tidyselect (>= 1.1.2),
diff --git a/NEWS.md b/NEWS.md
index fc1baccbc..1e4dc443c 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -8,6 +8,8 @@
 * A method for rsample's `int_pctl()` function that will compute percentile confidence intervals on performance metrics for objects produced by `fit_resamples()`, `tune_*()`, and `last_fit()`.
 
+* `last_fit()` now works with the 3-way validation split objects from `rsample::initial_validation_split()`. `last_fit()` and `fit_best()` now have a new argument `add_validation_set` to include or exclude the validation set in the dataset used to fit the model (#701).
+
 # tune 1.1.1
 
 * Fixed a bug introduced in tune 1.1.0 in `collect_()` functions where the
diff --git a/R/fit_best.R b/R/fit_best.R
index be786f4da..d14f86fb3 100644
--- a/R/fit_best.R
+++ b/R/fit_best.R
@@ -14,6 +14,12 @@
 #' If `NULL`, this argument will be set to
 #' [`select_best(metric)`][tune::select_best.tune_results].
 #' @param verbose A logical for printing logging.
+#' @param add_validation_set When the resamples embedded in `x` are a split into
+#' training set and validation set, should the validation set be included in the
+#' data set used to train the model. If not, only the training set is used. If
+#' `NULL`, the validation set is not used for resamples originating from
+#' [rsample::validation_set()] while it is used for resamples originating
+#' from [rsample::validation_split()].
 #' @param ... Not currently used.
 #' @details
 #' This function is a shortcut for the manual steps of:
@@ -84,6 +90,7 @@ fit_best.tune_results <- function(x,
                                   metric = NULL,
                                   parameters = NULL,
                                   verbose = FALSE,
+                                  add_validation_set = NULL,
                                   ...)
{ if (length(list(...))) { cli::cli_abort(c("x" = "The `...` are not used by this function.")) @@ -120,7 +127,29 @@ fit_best.tune_results <- function(x, # ---------------------------------------------------------------------------- - dat <- x$splits[[1]]$data + if (inherits(x$splits[[1]], "val_split")) { + if (is.null(add_validation_set)) { + rset_info <- attr(x, "rset_info") + originate_from_3way_split <- rset_info$att$origin_3way %||% FALSE + if (originate_from_3way_split) { + add_validation_set <- FALSE + } else { + add_validation_set <- TRUE + } + } + if (add_validation_set) { + dat <- x$splits[[1]]$data + } else { + dat <- rsample::training(x$splits[[1]]) + } + } else { + if (!is.null(add_validation_set)) { + rlang::warn( + "The option `add_validation_set` is being ignored because the resampling object does not include a validation set." + ) + } + dat <- x$splits[[1]]$data + } if (verbose) { cli::cli_inform(c("i" = "Fitting using {nrow(dat)} data points...")) } diff --git a/R/last_fit.R b/R/last_fit.R index 31ea0294d..d721ae3c0 100644 --- a/R/last_fit.R +++ b/R/last_fit.R @@ -12,7 +12,8 @@ #' @param preprocessor A traditional model formula or a recipe created using #' [recipes::recipe()]. #' -#' @param split An `rsplit` object created from [rsample::initial_split()]. +#' @param split An `rsplit` object created from [rsample::initial_split()] or +#' [rsample::initial_validation_split()]. #' #' @param metrics A [yardstick::metric_set()], or `NULL` to compute a standard #' set of metrics. @@ -25,6 +26,11 @@ #' values should be non-negative and should probably be no greater then the #' largest event time in the training set. #' +#' @param add_validation_set For 3-way splits into training, validation, and test +#' set via [rsample::initial_validation_split()], should the validation set be +#' included in the data set used to train the model. If not, only the training +#' set is used. +#' #' @param ... Currently unused. #' #' @details @@ -113,7 +119,8 @@ last_fit.model_fit <- function(object, ...) { #' @export #' @rdname last_fit last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL, - control = control_last_fit(), eval_time = NULL) { + control = control_last_fit(), eval_time = NULL, + add_validation_set = FALSE) { if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) { rlang::abort(paste( "To tune a model spec, you must preprocess", @@ -133,19 +140,20 @@ last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL wflow <- add_formula(wflow, preprocessor) } - last_fit_workflow(wflow, split, metrics, control, eval_time) + last_fit_workflow(wflow, split, metrics, control, eval_time, add_validation_set) } #' @rdname last_fit #' @export last_fit.workflow <- function(object, split, ..., metrics = NULL, - control = control_last_fit(), eval_time = NULL) { + control = control_last_fit(), eval_time = NULL, + add_validation_set = FALSE) { empty_ellipses(...) 
control <- parsnip::condense_control(control, control_last_fit()) - last_fit_workflow(object, split, metrics, control, eval_time) + last_fit_workflow(object, split, metrics, control, eval_time, add_validation_set) } @@ -154,6 +162,7 @@ last_fit_workflow <- function(object, metrics, control, eval_time = NULL, + add_validation_set = FALSE, ..., call = rlang::caller_env()) { rlang::check_dots_empty() @@ -166,6 +175,9 @@ last_fit_workflow <- function(object, ) } + if (inherits(split, "initial_validation_split")) { + split <- prepare_validation_split(split, add_validation_set) + } splits <- list(split) resamples <- rsample::manual_rset(splits, ids = "train/test split") @@ -190,3 +202,32 @@ last_fit_workflow <- function(object, .stash_last_result(res) res } + + +prepare_validation_split <- function(split, add_validation_set){ + if (add_validation_set) { + # equivalent to (unexported) rsample:::rsplit() without checks + split <- structure( + list( + data = split$data, + in_id = c(split$train_id, split$val_id), + out_id = NA + ), + class = "rsplit" + ) + } else { + id_train_test <- seq_len(nrow(split$data))[-sort(split$val_id)] + id_train <- match(split$train_id, id_train_test) + + split <- structure( + list( + data = split$data[-sort(split$val_id), , drop = FALSE], + in_id = id_train, + out_id = NA + ), + class = "rsplit" + ) + } + + split +} diff --git a/man/fit_best.Rd b/man/fit_best.Rd index aee4c4bf7..e687e102e 100644 --- a/man/fit_best.Rd +++ b/man/fit_best.Rd @@ -10,7 +10,14 @@ fit_best(x, ...) \method{fit_best}{default}(x, ...) -\method{fit_best}{tune_results}(x, metric = NULL, parameters = NULL, verbose = FALSE, ...) +\method{fit_best}{tune_results}( + x, + metric = NULL, + parameters = NULL, + verbose = FALSE, + add_validation_set = NULL, + ... +) } \arguments{ \item{x}{The results of class \code{tune_results} (coming from functions such as @@ -29,6 +36,13 @@ If \code{NULL}, this argument will be set to \code{\link[=select_best.tune_results]{select_best(metric)}}.} \item{verbose}{A logical for printing logging.} + +\item{add_validation_set}{When the resamples embedded in \code{x} are a split into +training set and validation set, should the validation set be included in the +data set used to train the model. If not, only the training set is used. If +\code{NULL}, the validation set is not used for resamples originating from +\code{\link[rsample:validation_set]{rsample::validation_set()}} while it is used for resamples originating +from \code{\link[rsample:validation_split]{rsample::validation_split()}}.} } \value{ A fitted workflow. diff --git a/man/last_fit.Rd b/man/last_fit.Rd index 740b162b2..6639eec03 100644 --- a/man/last_fit.Rd +++ b/man/last_fit.Rd @@ -15,7 +15,8 @@ last_fit(object, ...) ..., metrics = NULL, control = control_last_fit(), - eval_time = NULL + eval_time = NULL, + add_validation_set = FALSE ) \method{last_fit}{workflow}( @@ -24,7 +25,8 @@ last_fit(object, ...) 
..., metrics = NULL, control = control_last_fit(), - eval_time = NULL + eval_time = NULL, + add_validation_set = FALSE ) } \arguments{ @@ -38,7 +40,8 @@ have been marked with \link[hardhat:tune]{tune()}, their values must be \item{preprocessor}{A traditional model formula or a recipe created using \code{\link[recipes:recipe]{recipes::recipe()}}.} -\item{split}{An \code{rsplit} object created from \code{\link[rsample:initial_split]{rsample::initial_split()}}.} +\item{split}{An \code{rsplit} object created from \code{\link[rsample:initial_split]{rsample::initial_split()}} or +\code{\link[rsample:initial_validation_split]{rsample::initial_validation_split()}}.} \item{metrics}{A \code{\link[yardstick:metric_set]{yardstick::metric_set()}}, or \code{NULL} to compute a standard set of metrics.} @@ -50,6 +53,11 @@ process.} metrics should be computed (e.g. the time-dependent ROC curve, etc). The values should be non-negative and should probably be no greater then the largest event time in the training set.} + +\item{add_validation_set}{For 3-way splits into training, validation, and test +set via \code{\link[rsample:initial_validation_split]{rsample::initial_validation_split()}}, should the validation set be +included in the data set used to train the model. If not, only the training +set is used.} } \value{ A single row tibble that emulates the structure of \code{fit_resamples()}. diff --git a/tests/testthat/test-fit_best.R b/tests/testthat/test-fit_best.R index 5e8995e0f..f598f6550 100644 --- a/tests/testthat/test-fit_best.R +++ b/tests/testthat/test-fit_best.R @@ -63,3 +63,68 @@ test_that("fit_best", { fit_best(ames_iter_search) ) }) + +test_that("fit_best() works with validation split: 3-way split", { + skip_if_not_installed("kknn") + skip_if_not_installed("modeldata") + data(ames, package = "modeldata", envir = rlang::current_env()) + + set.seed(23598723) + initial_val_split <- rsample::initial_validation_split(ames) + val_set <- validation_set(initial_val_split) + + f <- Sale_Price ~ Gr_Liv_Area + Year_Built + knn_mod <- nearest_neighbor(neighbors = tune()) %>% set_mode("regression") + wflow <- workflow(f, knn_mod) + + tune_res <- tune_grid( + wflow, + grid = tibble(neighbors = c(1, 5)), + resamples = val_set, + control = control_grid(save_workflow = TRUE) + ) %>% suppressWarnings() + set.seed(3) + fit_on_train <- fit_best(tune_res) + pred <- predict(fit_on_train, testing(initial_val_split)) + + set.seed(3) + exp_fit_on_train <- nearest_neighbor(neighbors = 5) %>% + set_mode("regression") %>% + fit(f, training(initial_val_split)) + exp_pred <- predict(exp_fit_on_train, testing(initial_val_split)) + + expect_equal(pred, exp_pred) +}) + +test_that("fit_best() works with validation split: 2x 2-way splits", { + skip_if_not_installed("kknn") + skip_if_not_installed("modeldata") + data(ames, package = "modeldata", envir = rlang::current_env()) + + set.seed(23598723) + split <- rsample::initial_split(ames) + train_and_val <- training(split) + val_set <- rsample::validation_split(train_and_val) + + f <- Sale_Price ~ Gr_Liv_Area + Year_Built + knn_mod <- nearest_neighbor(neighbors = tune()) %>% set_mode("regression") + wflow <- workflow(f, knn_mod) + + tune_res <- tune_grid( + wflow, + grid = tibble(neighbors = c(1, 5)), + resamples = val_set, + control = control_grid(save_workflow = TRUE) + ) + set.seed(3) + fit_on_train_and_val <- fit_best(tune_res) + pred <- predict(fit_on_train_and_val, testing(split)) + + set.seed(3) + exp_fit_on_train_and_val <- nearest_neighbor(neighbors = 5) %>% + 
set_mode("regression") %>%
+    fit(f, train_and_val)
+  exp_pred <- predict(exp_fit_on_train_and_val, testing(split))
+
+  expect_equal(pred, exp_pred)
+})
diff --git a/tests/testthat/test-last-fit.R b/tests/testthat/test-last-fit.R
index db34440e3..2595ec59f 100644
--- a/tests/testthat/test-last-fit.R
+++ b/tests/testthat/test-last-fit.R
@@ -163,3 +163,68 @@ test_that("`last_fit()` when objects need tuning", {
   expect_snapshot_error(last_fit(wflow_2, split))
   expect_snapshot_error(last_fit(wflow_3, split))
 })
+
+test_that("last_fit() excludes validation set for initial_validation_split objects", {
+  skip_if_not_installed("modeldata")
+  data(ames, package = "modeldata", envir = rlang::current_env())
+
+  set.seed(23598723)
+  split <- rsample::initial_validation_split(ames)
+
+  f <- Sale_Price ~ Gr_Liv_Area + Year_Built
+  lm_fit <- lm(f, data = rsample::training(split))
+  test_pred <- predict(lm_fit, rsample::testing(split))
+  rsq_test <- yardstick::rsq_vec(rsample::testing(split) %>% pull(Sale_Price), test_pred)
+
+  res <- parsnip::linear_reg() %>%
+    parsnip::set_engine("lm") %>%
+    last_fit(f, split)
+
+  expect_equal(res, .Last.tune.result)
+
+  expect_equal(
+    coef(extract_fit_engine(res$.workflow[[1]])),
+    coef(lm_fit),
+    ignore_attr = TRUE
+  )
+  expect_equal(res$.metrics[[1]]$.estimate[[2]], rsq_test)
+  expect_equal(res$.predictions[[1]]$.pred, unname(test_pred))
+  expect_true(res$.workflow[[1]]$trained)
+  expect_equal(
+    nrow(predict(res$.workflow[[1]], rsample::testing(split))),
+    nrow(rsample::testing(split))
+  )
+})
+
+test_that("last_fit() can include validation set for initial_validation_split objects", {
+  skip_if_not_installed("modeldata")
+  data(ames, package = "modeldata", envir = rlang::current_env())
+
+  set.seed(23598723)
+  split <- rsample::initial_validation_split(ames)
+
+  f <- Sale_Price ~ Gr_Liv_Area + Year_Built
+  train_val <- rbind(rsample::training(split), rsample::validation(split))
+  lm_fit <- lm(f, data = train_val)
+  test_pred <- predict(lm_fit, rsample::testing(split))
+  rsq_test <- yardstick::rsq_vec(rsample::testing(split) %>% pull(Sale_Price), test_pred)
+
+  res <- parsnip::linear_reg() %>%
+    parsnip::set_engine("lm") %>%
+    last_fit(f, split, add_validation_set = TRUE)
+
+  expect_equal(res, .Last.tune.result)
+
+  expect_equal(
+    coef(extract_fit_engine(res$.workflow[[1]])),
+    coef(lm_fit),
+    ignore_attr = TRUE
+  )
+  expect_equal(res$.metrics[[1]]$.estimate[[2]], rsq_test)
+  expect_equal(res$.predictions[[1]]$.pred, unname(test_pred))
+  expect_true(res$.workflow[[1]]$trained)
+  expect_equal(
+    nrow(predict(res$.workflow[[1]], rsample::testing(split))),
+    nrow(rsample::testing(split))
+  )
+})
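
For review context, below is a minimal end-to-end sketch of the interface this patch adds. It is not part of the patch itself: the dataset, formula, model, grid, and the kknn engine are illustrative choices mirroring the tests above, and it assumes a tidymodels installation whose rsample version provides `initial_validation_split()` and `validation_set()`.

```r
# Illustrative sketch only -- mirrors the tests above; not part of the patch.
library(tidymodels)   # loads rsample, parsnip, workflows, tune, yardstick, ...

data(ames, package = "modeldata")

set.seed(1)
split <- initial_validation_split(ames)   # 3-way split: training / validation / test

f <- Sale_Price ~ Gr_Liv_Area + Year_Built
knn_spec <- nearest_neighbor(neighbors = tune()) %>%   # assumes the kknn engine is installed
  set_mode("regression")
wflow <- workflow(f, knn_spec)

# Tune against the single validation resample derived from the 3-way split.
tune_res <- tune_grid(
  wflow,
  resamples = validation_set(split),
  grid = tibble(neighbors = c(1, 5)),
  control = control_grid(save_workflow = TRUE)
)

# New in this patch: `last_fit()` accepts the 3-way split directly. By default the
# validation set is excluded from the data used for the final fit; opt back in
# with `add_validation_set = TRUE`.
final_wflow <- finalize_workflow(wflow, select_best(tune_res, metric = "rmse"))
final_res <- last_fit(final_wflow, split, add_validation_set = TRUE)
collect_metrics(final_res)   # metrics computed on the test set

# `fit_best()` gains the same argument; for resamples that originate from a
# 3-way split its default (NULL) refits on the training set only.
fit_best(tune_res, add_validation_set = FALSE)
```

The asymmetric `NULL` default in `fit_best()` follows the behavior documented in `R/fit_best.R`: resamples from `rsample::validation_set()` come from a 3-way split whose validation data is excluded by default, while resamples from `rsample::validation_split()` keep the validation data in the final training set.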