diff --git a/NEWS.md b/NEWS.md index a2a3aacd..6e35a852 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,7 @@ ## New features +* parsnip models now allow case weights to be passed transparently through `workflows::add_case_weights()` (#151) * parsnip models now support `tabnet_model` and `from_epoch` parameters (#143) ## Bugfixes diff --git a/R/hardhat.R b/R/hardhat.R index f8f70842..102bb1ab 100644 --- a/R/hardhat.R +++ b/R/hardhat.R @@ -32,6 +32,7 @@ #' If no argument is supplied, this will use the default values in [tabnet_config()]. #' @param from_epoch When a `tabnet_model` is provided, restore the network weights from a specific epoch. #' Default is last available checkpoint for restored model, or last epoch for in-memory model. +#' @param weights Unused. #' @param ... Model hyperparameters. #' Any hyperparameters set here will update those set by the config argument. #' See [tabnet_config()] for a list of all possible hyperparameters. @@ -111,7 +112,8 @@ tabnet_fit.default <- function(x, ...) 
{ #' @export #' @rdname tabnet_fit -tabnet_fit.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) { +tabnet_fit.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_config(), ..., + from_epoch = NULL, weights = NULL) { processed <- hardhat::mold(x, y) check_type(processed$outcomes) @@ -130,7 +132,8 @@ tabnet_fit.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_con #' @export #' @rdname tabnet_fit -tabnet_fit.formula <- function(formula, data, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) { +tabnet_fit.formula <- function(formula, data, tabnet_model = NULL, config = tabnet_config(), ..., + from_epoch = NULL, weights = NULL) { processed <- hardhat::mold( formula, data, blueprint = hardhat::default_formula_blueprint( @@ -155,7 +158,8 @@ tabnet_fit.formula <- function(formula, data, tabnet_model = NULL, config = tabn #' @export #' @rdname tabnet_fit -tabnet_fit.recipe <- function(x, data, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) { +tabnet_fit.recipe <- function(x, data, tabnet_model = NULL, config = tabnet_config(), ..., + from_epoch = NULL, weights = NULL) { processed <- hardhat::mold(x, data) check_type(processed$outcomes) diff --git a/R/parsnip.R b/R/parsnip.R index ab8dc0b1..98bfb728 100644 --- a/R/parsnip.R +++ b/R/parsnip.R @@ -24,7 +24,7 @@ add_parsnip_tabnet <- function() { mode = "classification", value = list( interface = "formula", - protect = c("formula", "data"), + protect = c("formula", "data", "weights"), func = c(pkg = "tabnet", fun = "tabnet_fit"), defaults = list() ) @@ -36,7 +36,7 @@ add_parsnip_tabnet <- function() { mode = "regression", value = list( interface = "formula", - protect = c("formula", "data"), + protect = c("formula", "data", "weights"), func = c(pkg = "tabnet", fun = "tabnet_fit"), defaults = list() ) diff --git a/man/tabnet_fit.Rd b/man/tabnet_fit.Rd index 19a6b94c..3c1d2c9f 100644 --- 
a/man/tabnet_fit.Rd +++ b/man/tabnet_fit.Rd @@ -19,7 +19,8 @@ tabnet_fit(x, ...) tabnet_model = NULL, config = tabnet_config(), ..., - from_epoch = NULL + from_epoch = NULL, + weights = NULL ) \method{tabnet_fit}{formula}( @@ -28,7 +29,8 @@ tabnet_fit(x, ...) tabnet_model = NULL, config = tabnet_config(), ..., - from_epoch = NULL + from_epoch = NULL, + weights = NULL ) \method{tabnet_fit}{recipe}( @@ -37,7 +39,8 @@ tabnet_fit(x, ...) tabnet_model = NULL, config = tabnet_config(), ..., - from_epoch = NULL + from_epoch = NULL, + weights = NULL ) \method{tabnet_fit}{Node}( @@ -82,6 +85,8 @@ If no argument is supplied, this will use the default values in \code{\link[=tab \item{from_epoch}{When a \code{tabnet_model} is provided, restore the network weights from a specific epoch. Default is last available checkpoint for restored model, or last epoch for in-memory model.} +\item{weights}{Unused.} + \item{formula}{A formula specifying the outcome terms on the left-hand side, and the predictor terms on the right-hand side.} diff --git a/tests/testthat/test-parsnip.R b/tests/testthat/test-parsnip.R index f741914e..a7748069 100644 --- a/tests/testthat/test-parsnip.R +++ b/tests/testthat/test-parsnip.R @@ -9,7 +9,7 @@ test_that("parsnip fit model works", { expect_no_error( fit <- model %>% - parsnip::fit(Sale_Price ~ ., data = ames) + parsnip::fit(Sale_Price ~ ., data = small_ames) ) # some setup params @@ -21,7 +21,7 @@ test_that("parsnip fit model works", { expect_no_error( fit <- model %>% - parsnip::fit(Sale_Price ~ ., data = ames) + parsnip::fit(Sale_Price ~ ., data = small_ames) ) # new batch of setup params @@ -33,7 +33,7 @@ test_that("parsnip fit model works", { expect_no_error( fit <- model %>% - parsnip::fit(Overall_Cond ~ ., data = ames) + parsnip::fit(Overall_Cond ~ ., data = small_ames) ) }) @@ -49,7 +49,7 @@ test_that("parsnip fit model works from a pretrained model", { expect_no_error( fit <- model %>% - parsnip::fit(Sale_Price ~ ., data = ames) + 
parsnip::fit(Sale_Price ~ ., data = small_ames) ) @@ -63,23 +63,19 @@ test_that("multi_predict works as expected", { parsnip::set_mode("regression") %>% parsnip::set_engine("torch") - data("ames", package = "modeldata") - expect_no_error( fit <- model %>% - parsnip::fit(Sale_Price ~ ., data = ames) + parsnip::fit(Sale_Price ~ ., data = small_ames) ) - preds <- parsnip::multi_predict(fit, ames, epochs = c(1,2,3,4,5)) + preds <- parsnip::multi_predict(fit, small_ames, epochs = c(1,2,3,4,5)) - expect_equal(nrow(preds), nrow(ames)) + expect_equal(nrow(preds), nrow(small_ames)) expect_equal(nrow(preds$.pred[[1]]), 5) }) test_that("Check we can finalize a workflow", { - data("ames", package = "modeldata") - model <- tabnet(penalty = tune(), epochs = tune()) %>% parsnip::set_mode("regression") %>% parsnip::set_engine("torch") @@ -91,7 +87,7 @@ test_that("Check we can finalize a workflow", { wf <- tune::finalize_workflow(wf, tibble::tibble(penalty = 0.01, epochs = 1)) expect_no_error( - fit <- wf %>% parsnip::fit(data = ames) + fit <- wf %>% parsnip::fit(data = small_ames) ) expect_equal(rlang::eval_tidy(wf$fit$actions$model$spec$args$penalty), 0.01) @@ -100,8 +96,6 @@ test_that("Check we can finalize a workflow", { test_that("Check we can finalize a workflow from a tune_grid", { - data("ames", package = "modeldata") - model <- tabnet(epochs = tune(), checkpoint_epochs = 1) %>% parsnip::set_mode("regression") %>% parsnip::set_engine("torch") @@ -111,7 +105,7 @@ test_that("Check we can finalize a workflow from a tune_grid", { workflows::add_formula(Sale_Price ~ .) 
custom_grid <- tidyr::crossing(epochs = c(1,2,3)) - cv_folds <- ames %>% + cv_folds <- small_ames %>% rsample::vfold_cv(v = 2, repeats = 1) at <- tune::tune_grid( @@ -212,3 +206,22 @@ test_that("tabnet grid reduction - torch", { expect_equal(reg_grid_smol$.submodels[[i]], list(`Ade Tukunbo` = 1:2)) } }) + +test_that("Check workflow can use case_weight", { + + small_ames_cw <- small_ames %>% dplyr::mutate(case_weight = hardhat::frequency_weights(Year_Sold)) + model <- tabnet(epochs = 3, checkpoint_epochs = 1) %>% + parsnip::set_mode("regression") %>% + parsnip::set_engine("torch") + + wf <- workflows::workflow() %>% + workflows::add_model(model) %>% + workflows::add_formula(Sale_Price ~ .) %>% + workflows::add_case_weights(case_weight) + + expect_no_error( + fit <- wf %>% parsnip::fit(data = small_ames_cw) + ) + + +})