Skip to content

Commit

Permalink
add suport for workflows::add_case_weights() and fix #145
Browse files Browse the repository at this point in the history
  • Loading branch information
cregouby committed Feb 17, 2024
1 parent 8f2fbf4 commit f2a700d
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 23 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## New features

* parsnip models now allow transparently passing case weights through `workflows::add_case_weights()` parameters (#151)
* parsnip models now support `tabnet_model` and `from_epoch` parameters (#143)

## Bugfixes
Expand Down
10 changes: 7 additions & 3 deletions R/hardhat.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#' If no argument is supplied, this will use the default values in [tabnet_config()].
#' @param from_epoch When a `tabnet_model` is provided, restore the network weights from a specific epoch.
#' Default is last available checkpoint for restored model, or last epoch for in-memory model.
#' @param weights Unused.
#' @param ... Model hyperparameters.
#' Any hyperparameters set here will update those set by the config argument.
#' See [tabnet_config()] for a list of all possible hyperparameters.
Expand Down Expand Up @@ -111,7 +112,8 @@ tabnet_fit.default <- function(x, ...) {

#' @export
#' @rdname tabnet_fit
tabnet_fit.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) {
tabnet_fit.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_config(), ...,
from_epoch = NULL, weights = NULL) {
processed <- hardhat::mold(x, y)
check_type(processed$outcomes)

Expand All @@ -130,7 +132,8 @@ tabnet_fit.data.frame <- function(x, y, tabnet_model = NULL, config = tabnet_con

#' @export
#' @rdname tabnet_fit
tabnet_fit.formula <- function(formula, data, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) {
tabnet_fit.formula <- function(formula, data, tabnet_model = NULL, config = tabnet_config(), ...,
from_epoch = NULL, weights = NULL) {
processed <- hardhat::mold(
formula, data,
blueprint = hardhat::default_formula_blueprint(
Expand All @@ -155,7 +158,8 @@ tabnet_fit.formula <- function(formula, data, tabnet_model = NULL, config = tabn

#' @export
#' @rdname tabnet_fit
tabnet_fit.recipe <- function(x, data, tabnet_model = NULL, config = tabnet_config(), ..., from_epoch = NULL) {
tabnet_fit.recipe <- function(x, data, tabnet_model = NULL, config = tabnet_config(), ...,
from_epoch = NULL, weights = NULL) {
processed <- hardhat::mold(x, data)
check_type(processed$outcomes)

Expand Down
4 changes: 2 additions & 2 deletions R/parsnip.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ add_parsnip_tabnet <- function() {
mode = "classification",
value = list(
interface = "formula",
protect = c("formula", "data"),
protect = c("formula", "data", "weights"),
func = c(pkg = "tabnet", fun = "tabnet_fit"),
defaults = list()
)
Expand All @@ -36,7 +36,7 @@ add_parsnip_tabnet <- function() {
mode = "regression",
value = list(
interface = "formula",
protect = c("formula", "data"),
protect = c("formula", "data", "weights"),
func = c(pkg = "tabnet", fun = "tabnet_fit"),
defaults = list()
)
Expand Down
11 changes: 8 additions & 3 deletions man/tabnet_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

43 changes: 28 additions & 15 deletions tests/testthat/test-parsnip.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ test_that("parsnip fit model works", {

expect_no_error(
fit <- model %>%
parsnip::fit(Sale_Price ~ ., data = ames)
parsnip::fit(Sale_Price ~ ., data = small_ames)
)

# some setup params
Expand All @@ -21,7 +21,7 @@ test_that("parsnip fit model works", {

expect_no_error(
fit <- model %>%
parsnip::fit(Sale_Price ~ ., data = ames)
parsnip::fit(Sale_Price ~ ., data = small_ames)
)

# new batch of setup params
Expand All @@ -33,7 +33,7 @@ test_that("parsnip fit model works", {

expect_no_error(
fit <- model %>%
parsnip::fit(Overall_Cond ~ ., data = ames)
parsnip::fit(Overall_Cond ~ ., data = small_ames)
)

})
Expand All @@ -49,7 +49,7 @@ test_that("parsnip fit model works from a pretrained model", {

expect_no_error(
fit <- model %>%
parsnip::fit(Sale_Price ~ ., data = ames)
parsnip::fit(Sale_Price ~ ., data = small_ames)
)


Expand All @@ -63,23 +63,19 @@ test_that("multi_predict works as expected", {
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")

data("ames", package = "modeldata")

expect_no_error(
fit <- model %>%
parsnip::fit(Sale_Price ~ ., data = ames)
parsnip::fit(Sale_Price ~ ., data = small_ames)
)

preds <- parsnip::multi_predict(fit, ames, epochs = c(1,2,3,4,5))
preds <- parsnip::multi_predict(fit, small_ames, epochs = c(1,2,3,4,5))

expect_equal(nrow(preds), nrow(ames))
expect_equal(nrow(preds), nrow(small_ames))
expect_equal(nrow(preds$.pred[[1]]), 5)
})

test_that("Check we can finalize a workflow", {

data("ames", package = "modeldata")

model <- tabnet(penalty = tune(), epochs = tune()) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")
Expand All @@ -91,7 +87,7 @@ test_that("Check we can finalize a workflow", {
wf <- tune::finalize_workflow(wf, tibble::tibble(penalty = 0.01, epochs = 1))

expect_no_error(
fit <- wf %>% parsnip::fit(data = ames)
fit <- wf %>% parsnip::fit(data = small_ames)
)

expect_equal(rlang::eval_tidy(wf$fit$actions$model$spec$args$penalty), 0.01)
Expand All @@ -100,8 +96,6 @@ test_that("Check we can finalize a workflow", {

test_that("Check we can finalize a workflow from a tune_grid", {

data("ames", package = "modeldata")

model <- tabnet(epochs = tune(), checkpoint_epochs = 1) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")
Expand All @@ -111,7 +105,7 @@ test_that("Check we can finalize a workflow from a tune_grid", {
workflows::add_formula(Sale_Price ~ .)

custom_grid <- tidyr::crossing(epochs = c(1,2,3))
cv_folds <- ames %>%
cv_folds <- small_ames %>%
rsample::vfold_cv(v = 2, repeats = 1)

at <- tune::tune_grid(
Expand Down Expand Up @@ -212,3 +206,22 @@ test_that("tabnet grid reduction - torch", {
expect_equal(reg_grid_smol$.submodels[[i]], list(`Ade Tukunbo` = 1:2))
}
})

test_that("Check workflow can use case_weight", {

small_ames_cw <- small_ames %>% dplyr::mutate(case_weight = hardhat::frequency_weights(Year_Sold))
model <- tabnet(epochs = 3, checkpoint_epochs = 1) %>%
parsnip::set_mode("regression") %>%
parsnip::set_engine("torch")

wf <- workflows::workflow() %>%
workflows::add_model(model) %>%
workflows::add_formula(Sale_Price ~ .) %>%
workflows::add_case_weights(case_weight)

expect_no_error(
fit <- wf %>% parsnip::fit(data = small_ames_cw)
)


})

0 comments on commit f2a700d

Please sign in to comment.