Feature request for case-weights
#145
Hello @cgoo4, the feature branch can be installed with:

```r
pak::pkg_install("mlverse/tabnet@feature/case_weight")
```
Hi @cregouby - Thank you. Ahead of trying it on my own data, I've made a quick test using the toy `lending_club` dataset from modeldata:

```r
library(tabnet)
library(tidymodels)
library(modeldata)
library(patchwork)

data("lending_club", package = "modeldata")

# Weight the minority "bad" class by the good/bad ratio
class_ratio <- lending_club |>
  summarise(sum(Class == "good") / sum(Class == "bad")) |>
  pull()

lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

# Every column that is not the outcome, an id or the case weights is a predictor
tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

# XGBoost needs the factor columns dummy-encoded
xgb_rec <- tab_rec |>
  step_dummy(term, sub_grade, addr_state, verification_status, emp_length)

tab_mod <- tabnet(epochs = 100) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

xgb_mod <- boost_tree(trees = 100) |>
  set_engine("xgboost") |>
  set_mode("classification")

tab_wf <- workflow() |>
  add_model(tab_mod) |>
  add_recipe(tab_rec) |>
  add_case_weights(case_wts)

xgb_wf <- workflow() |>
  add_model(xgb_mod) |>
  add_recipe(xgb_rec) |>
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
xgb_fit <- xgb_wf |> fit(train)

tab_test <- tab_fit |> augment(test)
xgb_test <- xgb_fit |> augment(test)

# Precision-recall curves, with and without weighting the evaluation
p1 <- tab_test |>
  pr_curve(Class, .pred_good, case_weights = case_wts) |>
  autoplot() +
  ggtitle("TabNet with Case Weights") +
  theme(plot.title = element_text(size = 9))

p2 <- tab_test |>
  pr_curve(Class, .pred_good) |>
  autoplot() +
  ggtitle("TabNet WITHOUT") +
  theme(plot.title = element_text(size = 9))

p3 <- xgb_test |>
  pr_curve(Class, .pred_good, case_weights = case_wts) |>
  autoplot() +
  ggtitle("XGBoost with Case Weights") +
  theme(plot.title = element_text(size = 9))

p4 <- xgb_test |>
  pr_curve(Class, .pred_good) |>
  autoplot() +
  ggtitle("XGBoost WITHOUT") +
  theme(plot.title = element_text(size = 9))

p1 + p2 + p3 + p4
```

Created on 2024-02-18 with reprex v2.1.0
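A quick base-R check of the weighting rule used above, on a hypothetical 9:1 toy outcome rather than the real `lending_club` data: giving each `bad` row the weight `good / bad` and each `good` row a weight of 1 makes the total weight of the two classes equal. This is only a sketch of the arithmetic; it does not call tabnet or any tidymodels function.

```r
# Hypothetical imbalanced outcome standing in for lending_club$Class
cls <- factor(c(rep("good", 9), rep("bad", 1)), levels = c("bad", "good"))

# Same rule as in the reprex: minority rows get good/bad, majority rows get 1
class_ratio <- sum(cls == "good") / sum(cls == "bad")
case_wts <- ifelse(cls == "bad", class_ratio, 1)

# Total weight per class: both classes now sum to the same value
class_totals <- tapply(case_wts, cls, sum)
print(class_totals)
```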
- add support for `workflows::add_case_weights()` and fix #145
- add message for unused `weights` and translation
In 0.6.0.9000 I'm getting the message `Configured weights will not be used`. It's the same example as above, where the weights are passed along from the workflow:

```r
library(tabnet)
library(tidymodels)

data("lending_club", package = "modeldata")

# Weight the minority "bad" class by the good/bad ratio
class_ratio <- lending_club |>
  summarise(sum(Class == "good") / sum(Class == "bad")) |>
  pull()

lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test <- testing(split)

tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

tab_mod <- tabnet(epochs = 10) |>
  set_engine("torch", device = "cpu") |>
  set_mode("classification")

tab_wf <- workflow() |>
  add_model(tab_mod) |>
  add_recipe(tab_rec) |>
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
#> Configured `weights` will not be used

tab_test <- tab_fit |> augment(test)
```

Created on 2024-08-09 with reprex v2.1.1
Hello @cgoo4, I added the message on purpose, and it may be misleading. It means "the tabnet model will be fit without using the case_weights variable." That is how tabnet actually treats the case_weights variable: it is set apart for later use by other downstream tidymodels packages. Any proposal for a more informative message?
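To illustrate what "set apart for downstream use" can mean, here is a minimal base-R sketch, with made-up predictions and weights, of how case weights change an evaluation metric even when the model itself was fit without them. Yardstick metrics such as `pr_curve()` accept a `case_weights` argument for this kind of weighted evaluation; the manual weighted-accuracy formula below is an illustration only, not yardstick's implementation.

```r
# Hypothetical truth, predictions and case weights (minority "bad" weighted up)
truth    <- factor(c("bad", "bad", "good", "good", "good"), levels = c("bad", "good"))
estimate <- factor(c("bad", "good", "good", "good", "bad"), levels = c("bad", "good"))
wts      <- c(3, 3, 1, 1, 1)

correct <- truth == estimate

# Unweighted accuracy: every row counts once
unweighted_acc <- mean(correct)

# Weighted accuracy: each row counts in proportion to its case weight
weighted_acc <- sum(wts * correct) / sum(wts)

print(c(unweighted = unweighted_acc, weighted = weighted_acc))
```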
Hi @cregouby - Thank you for clarifying. If it's possible to set case_weights in more than one way, e.g. in a tidymodels …
Would it be possible to add support for case weights in TabNet?
This would help with class imbalance and make it easier to compare (and blend) the results of TabNet and XGBoost.
(I will probably upsample the minority class in the meantime as an alternative approach.)
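Upsampling the minority class, as mentioned above, can be sketched in base R by resampling minority rows with replacement until both classes have the same count. The data frame below is hypothetical. Within a tidymodels recipe, the `themis` package provides `step_upsample()` for the same purpose; like other themis steps it is skipped when new data are baked, so the resampling only affects the training set.

```r
set.seed(1)

# Hypothetical imbalanced data: 8 "good" rows, 2 "bad" rows
df <- data.frame(x = 1:10, Class = factor(c(rep("good", 8), rep("bad", 2))))

minority <- df[df$Class == "bad", ]
n_extra  <- sum(df$Class == "good") - nrow(minority)

# Draw extra minority rows with replacement until the classes are balanced
extra     <- minority[sample(nrow(minority), n_extra, replace = TRUE), ]
upsampled <- rbind(df, extra)

print(table(upsampled$Class))
```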
This would be the desired workflow:
Created on 2024-01-12 with reprex v2.0.2