Skip to content

Commit

Permalink
Support for 3-way validation split interface (#701)
Browse files Browse the repository at this point in the history
* support `initial_validation_split` objects in `last_fit()`

* add `use_validation_set` arg

* rename arg to `add_validation_set`

* move out of main function for readability

* fix namespace

* update remote to main branch
  • Loading branch information
hfrick authored Jul 27, 2023
1 parent c58bfb1 commit eaca118
Show file tree
Hide file tree
Showing 8 changed files with 235 additions and 11 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Imports:
purrr (>= 1.0.0),
recipes (>= 1.0.4),
rlang (>= 1.1.0),
rsample (>= 1.0.0),
rsample (>= 1.1.1.9001),
tibble (>= 3.1.0),
tidyr (>= 1.2.0),
tidyselect (>= 1.1.2),
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

* A method for rsample's `int_pctl()` function that will compute percentile confidence intervals on performance metrics for objects produced by `fit_resamples()`, `tune_*()`, and `last_fit()`.

* `last_fit()` now works with the 3-way validation split objects from `rsample::initial_validation_split()`. `last_fit()` and `fit_best()` now have a new argument `add_validation_set` to include or exclude the validation set in the dataset used to fit the model (#701).

# tune 1.1.1

* Fixed a bug introduced in tune 1.1.0 in `collect_()` functions where the
Expand Down
31 changes: 30 additions & 1 deletion R/fit_best.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
#' If `NULL`, this argument will be set to
#' [`select_best(metric)`][tune::select_best.tune_results].
#' @param verbose A logical for printing logging.
#' @param add_validation_set When the resamples embedded in `x` are a split into
#' training set and validation set, should the validation set be included in the
#' data set used to train the model. If not, only the training set is used. If
#' `NULL`, the validation set is not used for resamples originating from
#' [rsample::validation_set()] while it is used for resamples originating
#' from [rsample::validation_split()].
#' @param ... Not currently used.
#' @details
#' This function is a shortcut for the manual steps of:
Expand Down Expand Up @@ -84,6 +90,7 @@ fit_best.tune_results <- function(x,
metric = NULL,
parameters = NULL,
verbose = FALSE,
add_validation_set = NULL,
...) {
if (length(list(...))) {
cli::cli_abort(c("x" = "The `...` are not used by this function."))
Expand Down Expand Up @@ -120,7 +127,29 @@ fit_best.tune_results <- function(x,

# ----------------------------------------------------------------------------

dat <- x$splits[[1]]$data
if (inherits(x$splits[[1]], "val_split")) {
if (is.null(add_validation_set)) {
rset_info <- attr(x, "rset_info")
originate_from_3way_split <- rset_info$att$origin_3way %||% FALSE
if (originate_from_3way_split) {
add_validation_set <- FALSE
} else {
add_validation_set <- TRUE
}
}
if (add_validation_set) {
dat <- x$splits[[1]]$data
} else {
dat <- rsample::training(x$splits[[1]])
}
} else {
if (!is.null(add_validation_set)) {
rlang::warn(
"The option `add_validation_set` is being ignored because the resampling object does not include a validation set."
)
}
dat <- x$splits[[1]]$data
}
if (verbose) {
cli::cli_inform(c("i" = "Fitting using {nrow(dat)} data points..."))
}
Expand Down
51 changes: 46 additions & 5 deletions R/last_fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
#' @param preprocessor A traditional model formula or a recipe created using
#' [recipes::recipe()].
#'
#' @param split An `rsplit` object created from [rsample::initial_split()].
#' @param split An `rsplit` object created from [rsample::initial_split()] or
#' [rsample::initial_validation_split()].
#'
#' @param metrics A [yardstick::metric_set()], or `NULL` to compute a standard
#' set of metrics.
Expand All @@ -25,6 +26,11 @@
#' values should be non-negative and should probably be no greater then the
#' largest event time in the training set.
#'
#' @param add_validation_set For 3-way splits into training, validation, and test
#' set via [rsample::initial_validation_split()], should the validation set be
#' included in the data set used to train the model. If not, only the training
#' set is used.
#'
#' @param ... Currently unused.
#'
#' @details
Expand Down Expand Up @@ -113,7 +119,8 @@ last_fit.model_fit <- function(object, ...) {
#' @export
#' @rdname last_fit
last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL,
control = control_last_fit(), eval_time = NULL) {
control = control_last_fit(), eval_time = NULL,
add_validation_set = FALSE) {
if (rlang::is_missing(preprocessor) || !is_preprocessor(preprocessor)) {
rlang::abort(paste(
"To tune a model spec, you must preprocess",
Expand All @@ -133,19 +140,20 @@ last_fit.model_spec <- function(object, preprocessor, split, ..., metrics = NULL
wflow <- add_formula(wflow, preprocessor)
}

last_fit_workflow(wflow, split, metrics, control, eval_time)
last_fit_workflow(wflow, split, metrics, control, eval_time, add_validation_set)
}


#' @rdname last_fit
#' @export
last_fit.workflow <- function(object, split, ..., metrics = NULL,
control = control_last_fit(), eval_time = NULL) {
control = control_last_fit(), eval_time = NULL,
add_validation_set = FALSE) {
empty_ellipses(...)

control <- parsnip::condense_control(control, control_last_fit())

last_fit_workflow(object, split, metrics, control, eval_time)
last_fit_workflow(object, split, metrics, control, eval_time, add_validation_set)
}


Expand All @@ -154,6 +162,7 @@ last_fit_workflow <- function(object,
metrics,
control,
eval_time = NULL,
add_validation_set = FALSE,
...,
call = rlang::caller_env()) {
rlang::check_dots_empty()
Expand All @@ -166,6 +175,9 @@ last_fit_workflow <- function(object,
)
}

if (inherits(split, "initial_validation_split")) {
split <- prepare_validation_split(split, add_validation_set)
}
splits <- list(split)
resamples <- rsample::manual_rset(splits, ids = "train/test split")

Expand All @@ -190,3 +202,32 @@ last_fit_workflow <- function(object,
.stash_last_result(res)
res
}


prepare_validation_split <- function(split, add_validation_set){
if (add_validation_set) {
# equivalent to (unexported) rsample:::rsplit() without checks
split <- structure(
list(
data = split$data,
in_id = c(split$train_id, split$val_id),
out_id = NA
),
class = "rsplit"
)
} else {
id_train_test <- seq_len(nrow(split$data))[-sort(split$val_id)]
id_train <- match(split$train_id, id_train_test)

split <- structure(
list(
data = split$data[-sort(split$val_id), , drop = FALSE],
in_id = id_train,
out_id = NA
),
class = "rsplit"
)
}

split
}
16 changes: 15 additions & 1 deletion man/fit_best.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 11 additions & 3 deletions man/last_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions tests/testthat/test-fit_best.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,68 @@ test_that("fit_best", {
fit_best(ames_iter_search)
)
})

test_that("fit_best() works with validation split: 3-way split", {
skip_if_not_installed("kknn")
skip_if_not_installed("modeldata")
data(ames, package = "modeldata", envir = rlang::current_env())

set.seed(23598723)
initial_val_split <- rsample::initial_validation_split(ames)
val_set <- validation_set(initial_val_split)

f <- Sale_Price ~ Gr_Liv_Area + Year_Built
knn_mod <- nearest_neighbor(neighbors = tune()) %>% set_mode("regression")
wflow <- workflow(f, knn_mod)

tune_res <- tune_grid(
wflow,
grid = tibble(neighbors = c(1, 5)),
resamples = val_set,
control = control_grid(save_workflow = TRUE)
) %>% suppressWarnings()
set.seed(3)
fit_on_train <- fit_best(tune_res)
pred <- predict(fit_on_train, testing(initial_val_split))

set.seed(3)
exp_fit_on_train <- nearest_neighbor(neighbors = 5) %>%
set_mode("regression") %>%
fit(f, training(initial_val_split))
exp_pred <- predict(exp_fit_on_train, testing(initial_val_split))

expect_equal(pred, exp_pred)
})

test_that("fit_best() works with validation split: 2x 2-way splits", {
skip_if_not_installed("kknn")
skip_if_not_installed("modeldata")
data(ames, package = "modeldata", envir = rlang::current_env())

set.seed(23598723)
split <- rsample::initial_split(ames)
train_and_val <- training(split)
val_set <- rsample::validation_split(train_and_val)

f <- Sale_Price ~ Gr_Liv_Area + Year_Built
knn_mod <- nearest_neighbor(neighbors = tune()) %>% set_mode("regression")
wflow <- workflow(f, knn_mod)

tune_res <- tune_grid(
wflow,
grid = tibble(neighbors = c(1, 5)),
resamples = val_set,
control = control_grid(save_workflow = TRUE)
)
set.seed(3)
fit_on_train_and_val <- fit_best(tune_res)
pred <- predict(fit_on_train_and_val, testing(split))

set.seed(3)
exp_fit_on_train_and_val <- nearest_neighbor(neighbors = 5) %>%
set_mode("regression") %>%
fit(f, train_and_val)
exp_pred <- predict(exp_fit_on_train_and_val, testing(split))

expect_equal(pred, exp_pred)
})
65 changes: 65 additions & 0 deletions tests/testthat/test-last-fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -163,3 +163,68 @@ test_that("`last_fit()` when objects need tuning", {
expect_snapshot_error(last_fit(wflow_2, split))
expect_snapshot_error(last_fit(wflow_3, split))
})

test_that("last_fit() excludes validation set for initial_validation_split objects", {
skip_if_not_installed("modeldata")
data(ames, package = "modeldata", envir = rlang::current_env())

set.seed(23598723)
split <- rsample::initial_validation_split(ames)

f <- Sale_Price ~ Gr_Liv_Area + Year_Built
lm_fit <- lm(f, data = rsample::training(split))
test_pred <- predict(lm_fit, rsample::testing(split))
rmse_test <- yardstick::rsq_vec(rsample::testing(split) %>% pull(Sale_Price), test_pred)

res <- parsnip::linear_reg() %>%
parsnip::set_engine("lm") %>%
last_fit(f, split)

expect_equal(res, .Last.tune.result)

expect_equal(
coef(extract_fit_engine(res$.workflow[[1]])),
coef(lm_fit),
ignore_attr = TRUE
)
expect_equal(res$.metrics[[1]]$.estimate[[2]], rmse_test)
expect_equal(res$.predictions[[1]]$.pred, unname(test_pred))
expect_true(res$.workflow[[1]]$trained)
expect_equal(
nrow(predict(res$.workflow[[1]], rsample::testing(split))),
nrow(rsample::testing(split))
)
})

test_that("last_fit() can include validation set for initial_validation_split objects", {
skip_if_not_installed("modeldata")
data(ames, package = "modeldata", envir = rlang::current_env())

set.seed(23598723)
split <- rsample::initial_validation_split(ames)

f <- Sale_Price ~ Gr_Liv_Area + Year_Built
train_val <- rbind(rsample::training(split), rsample::validation(split))
lm_fit <- lm(f, data = train_val)
test_pred <- predict(lm_fit, rsample::testing(split))
rmse_test <- yardstick::rsq_vec(rsample::testing(split) %>% pull(Sale_Price), test_pred)

res <- parsnip::linear_reg() %>%
parsnip::set_engine("lm") %>%
last_fit(f, split, add_validation_set = TRUE)

expect_equal(res, .Last.tune.result)

expect_equal(
coef(extract_fit_engine(res$.workflow[[1]])),
coef(lm_fit),
ignore_attr = TRUE
)
expect_equal(res$.metrics[[1]]$.estimate[[2]], rmse_test)
expect_equal(res$.predictions[[1]]$.pred, unname(test_pred))
expect_true(res$.workflow[[1]]$trained)
expect_equal(
nrow(predict(res$.workflow[[1]], rsample::testing(split))),
nrow(rsample::testing(split))
)
})

0 comments on commit eaca118

Please sign in to comment.