Feature request for case-weights #145

Open
cgoo4 opened this issue Jan 12, 2024 · 5 comments · Fixed by #151

cgoo4 commented Jan 12, 2024

Would it be possible to add support for case weights in TabNet?

This would help with class imbalance and make it easier to compare (and blend) the results of TabNet and XGBoost.

(I will probably upsample the minority class in the meantime as an alternative approach.)

This would be the desired workflow:

library(tabnet)
library(tidymodels)
library(modeldata)

set.seed(123)
data("lending_club", package = "modeldata")

class_ratio <- lending_club |> 
  summarise(sum(Class == "good") / sum(Class == "bad")) |> 
  pull()

lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

set.seed(1)

tab_mod <- tabnet(epochs = 10) |> 
  set_engine("torch", device = "cpu") |> 
  set_mode("classification")

tab_wf <- workflow() |> 
  add_model(tab_mod) |> 
  add_recipe(tab_rec) |> 
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
#> Error in `check_case_weights()`:
#> ! Case weights are not enabled by the underlying model implementation.
#> Backtrace:
#>      ▆
#>   1. ├─generics::fit(tab_wf, train)
#>   2. └─workflows:::fit.workflow(tab_wf, train)
#>   3.   └─workflows::.fit_model(workflow, control)
#>   4.     ├─generics::fit(action_model, workflow = workflow, control = control)
#>   5.     └─workflows:::fit.action_model(...)
#>   6.       └─workflows:::fit_from_xy(spec, mold, case_weights, control_parsnip)
#>   7.         ├─generics::fit_xy(...)
#>   8.         └─parsnip::fit_xy.model_spec(...)
#>   9.           └─parsnip:::check_case_weights(case_weights, object)
#>  10.             └─rlang::abort("Case weights are not enabled by the underlying model implementation.")

Created on 2024-01-12 with reprex v2.0.2
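
The weighting scheme used in the reprex above can be checked with a small base R sketch. The class counts below are made up for illustration (they are not the `lending_club` counts); the point is only that giving each "bad" row a weight of `class_ratio` makes the total weight of the two classes equal.

```r
# Toy outcome with a made-up imbalance: 95 "good" rows and 5 "bad" rows
Class <- factor(c(rep("good", 95), rep("bad", 5)), levels = c("bad", "good"))

# Ratio of majority to minority counts, computed as in the reprex
class_ratio <- sum(Class == "good") / sum(Class == "bad")

# Each "bad" row gets weight class_ratio, each "good" row gets weight 1
case_wts <- ifelse(Class == "bad", class_ratio, 1)

# Sum of weights per class: both classes now carry the same total weight
totals <- tapply(case_wts, Class, sum)

print(class_ratio)
print(totals)
```

In the reprex the same numeric vector is then wrapped with `importance_weights()` so that tidymodels treats it as case weights rather than as a predictor.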

@cregouby
Collaborator

Hello @cgoo4
I finally did it. Would you like to test it and report whether it fits your needs?
One way to install it is:

pak::pkg_install("mlverse/tabnet@feature/case_weight")

@cgoo4
Author

cgoo4 commented Feb 18, 2024

Hi @cregouby - Thank you.

Ahead of trying it on my own data, I've made a quick test using the toy lending_club data. Untuned TabNet and XGBoost models, with and without case weights, show comparable results!

library(tabnet)
library(tidymodels)
library(modeldata)
library(patchwork)

data("lending_club", package = "modeldata")

class_ratio <- lending_club |> 
  summarise(sum(Class == "good") / sum(Class == "bad")) |> 
  pull()

lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

xgb_rec <- tab_rec |> 
  step_dummy(term, sub_grade, addr_state, verification_status, emp_length)

tab_mod <- tabnet(epochs = 100) |> 
  set_engine("torch", device = "cpu") |> 
  set_mode("classification")

xgb_mod <- boost_tree(trees = 100) |> 
  set_engine("xgboost") |> 
  set_mode("classification")

tab_wf <- workflow() |> 
  add_model(tab_mod) |> 
  add_recipe(tab_rec) |> 
  add_case_weights(case_wts)

xgb_wf <- workflow() |> 
  add_model(xgb_mod) |> 
  add_recipe(xgb_rec) |> 
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
xgb_fit <- xgb_wf |> fit(train)

tab_test <- tab_fit |> augment(test)
xgb_test <- xgb_fit |> augment(test)

p1 <- tab_test |> 
  pr_curve(Class, .pred_good, case_weights = case_wts) |> 
  autoplot() +
  ggtitle("TabNet with Case Weights") +
  theme(plot.title = element_text(size = 9))

p2 <- tab_test |> 
  pr_curve(Class, .pred_good) |> 
  autoplot() +
  ggtitle("TabNet WITHOUT") +
  theme(plot.title = element_text(size = 9))

p3 <- xgb_test |> 
  pr_curve(Class, .pred_good, case_weights = case_wts) |> 
  autoplot() +
  ggtitle("XGBoost with Case Weights") +
  theme(plot.title = element_text(size = 9))

p4 <- xgb_test |> 
  pr_curve(Class, .pred_good) |> 
  autoplot() +
  ggtitle("XGBoost WITHOUT") +
  theme(plot.title = element_text(size = 9))

p1 + p2 + p3 + p4

Created on 2024-02-18 with reprex v2.1.0

@cgoo4 cgoo4 closed this as completed Feb 19, 2024
cregouby added a commit that referenced this issue Feb 21, 2024
* add support for workflows::add_case_weights() and fix #145
* add message for unused `weights` and translation
@cgoo4
Author

cgoo4 commented Aug 9, 2024

In 0.6.0.9000 I'm getting the message "Configured `weights` will not be used":

(It's the same example as above, where the weights were passed along from the workflow.)

library(tabnet)
library(tidymodels)

data("lending_club", package = "modeldata")

class_ratio <- lending_club |> 
  summarise(sum(Class == "good") / sum(Class == "bad")) |> 
  pull()

lending_club <- lending_club |>
  mutate(
    case_wts = if_else(Class == "bad", class_ratio, 1),
    case_wts = importance_weights(case_wts)
  )

set.seed(123)
split <- initial_split(lending_club, strata = Class)
train <- training(split)
test  <- testing(split)

tab_rec <-
  train |>
  recipe() |>
  update_role(Class, new_role = "outcome") |>
  update_role(-has_role(c("outcome", "id", "case_weights")), new_role = "predictor")

tab_mod <- tabnet(epochs = 10) |> 
  set_engine("torch", device = "cpu") |> 
  set_mode("classification")

tab_wf <- workflow() |> 
  add_model(tab_mod) |> 
  add_recipe(tab_rec) |> 
  add_case_weights(case_wts)

tab_fit <- tab_wf |> fit(train)
#> Configured `weights` will not be used

tab_test <- tab_fit |> augment(test)

Created on 2024-08-09 with reprex v2.1.1

Session info

sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.1 (2024-06-14)
#>  os       macOS Sonoma 14.5
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/London
#>  date     2024-08-09
#>  pandoc   3.2.1 @ /opt/homebrew/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  backports      1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>  bit            4.0.5      2022-11-15 [1] CRAN (R 4.4.0)
#>  bit64          4.0.5      2020-08-30 [1] CRAN (R 4.4.0)
#>  broom        * 1.0.6      2024-05-17 [1] CRAN (R 4.4.0)
#>  callr          3.7.6      2024-03-25 [1] CRAN (R 4.4.0)
#>  class          7.3-22     2023-05-03 [2] CRAN (R 4.4.1)
#>  cli            3.6.3      2024-06-21 [1] CRAN (R 4.4.0)
#>  codetools      0.2-20     2024-03-31 [2] CRAN (R 4.4.1)
#>  colorspace     2.1-1      2024-07-26 [1] CRAN (R 4.4.0)
#>  coro           1.0.4      2024-03-11 [1] CRAN (R 4.4.0)
#>  data.table     1.15.4     2024-03-30 [1] CRAN (R 4.4.0)
#>  dials        * 1.3.0      2024-07-30 [1] CRAN (R 4.4.0)
#>  DiceDesign     1.10       2023-12-07 [1] CRAN (R 4.4.0)
#>  digest         0.6.36     2024-06-23 [1] CRAN (R 4.4.0)
#>  dplyr        * 1.1.4      2023-11-17 [1] CRAN (R 4.4.0)
#>  evaluate       0.24.0     2024-06-10 [1] CRAN (R 4.4.0)
#>  fansi          1.0.6      2023-12-08 [1] CRAN (R 4.4.0)
#>  fastmap        1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.4.0)
#>  fs             1.6.4      2024-04-25 [1] CRAN (R 4.4.0)
#>  furrr          0.3.1      2022-08-15 [1] CRAN (R 4.4.0)
#>  future         1.34.0     2024-07-29 [1] CRAN (R 4.4.0)
#>  future.apply   1.11.2     2024-03-28 [1] CRAN (R 4.4.0)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.4.0)
#>  ggplot2      * 3.5.1      2024-04-23 [1] CRAN (R 4.4.0)
#>  globals        0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>  glue           1.7.0      2024-01-09 [1] CRAN (R 4.4.0)
#>  gower          1.0.1      2022-12-22 [1] CRAN (R 4.4.0)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.4.0)
#>  gtable         0.3.5      2024-04-22 [1] CRAN (R 4.4.0)
#>  hardhat        1.4.0      2024-06-02 [1] CRAN (R 4.4.0)
#>  htmltools      0.5.8.1    2024-04-04 [1] CRAN (R 4.4.0)
#>  infer        * 1.0.7      2024-03-25 [1] CRAN (R 4.4.0)
#>  ipred          0.9-15     2024-07-18 [1] CRAN (R 4.4.0)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.4.0)
#>  jsonlite       1.8.8      2023-12-04 [1] CRAN (R 4.4.0)
#>  knitr          1.48       2024-07-07 [1] CRAN (R 4.4.0)
#>  lattice        0.22-6     2024-03-20 [2] CRAN (R 4.4.1)
#>  lava           1.8.0      2024-03-05 [1] CRAN (R 4.4.0)
#>  lhs            1.2.0      2024-06-30 [1] CRAN (R 4.4.0)
#>  lifecycle      1.0.4      2023-11-07 [1] CRAN (R 4.4.0)
#>  listenv        0.9.1      2024-01-29 [1] CRAN (R 4.4.0)
#>  lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.4.0)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.4.0)
#>  MASS           7.3-60.2   2024-04-26 [2] CRAN (R 4.4.1)
#>  Matrix         1.7-0      2024-04-26 [2] CRAN (R 4.4.1)
#>  modeldata    * 1.4.0      2024-06-19 [1] CRAN (R 4.4.0)
#>  munsell        0.5.1      2024-04-01 [1] CRAN (R 4.4.0)
#>  nnet           7.3-19     2023-05-03 [2] CRAN (R 4.4.1)
#>  parallelly     1.38.0     2024-07-27 [1] CRAN (R 4.4.0)
#>  parsnip      * 1.2.1      2024-03-22 [1] CRAN (R 4.4.0)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.4.0)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.4.0)
#>  processx       3.8.4      2024-03-16 [1] CRAN (R 4.4.0)
#>  prodlim        2024.06.25 2024-06-24 [1] CRAN (R 4.4.0)
#>  ps             1.7.7      2024-07-02 [1] CRAN (R 4.4.0)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.4.0)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.4.0)
#>  Rcpp           1.0.13     2024-07-17 [1] CRAN (R 4.4.0)
#>  recipes      * 1.1.0      2024-07-04 [1] CRAN (R 4.4.0)
#>  reprex         2.1.1      2024-07-06 [1] CRAN (R 4.4.0)
#>  rlang          1.1.4      2024-06-04 [1] CRAN (R 4.4.0)
#>  rmarkdown      2.27       2024-05-17 [1] CRAN (R 4.4.0)
#>  rpart          4.1.23     2023-12-05 [2] CRAN (R 4.4.1)
#>  rsample      * 1.2.1      2024-03-25 [1] CRAN (R 4.4.0)
#>  rstudioapi     0.16.0     2024-03-24 [1] CRAN (R 4.4.0)
#>  safetensors    0.1.2      2023-09-12 [1] CRAN (R 4.4.0)
#>  scales       * 1.3.0      2023-11-28 [1] CRAN (R 4.4.0)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.4.0)
#>  survival       3.6-4      2024-04-24 [2] CRAN (R 4.4.1)
#>  tabnet       * 0.6.0.9000 2024-08-09 [1] Github (mlverse/tabnet@c8c82d2)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.4.0)
#>  tidymodels   * 1.2.0      2024-03-25 [1] CRAN (R 4.4.0)
#>  tidyr        * 1.3.1      2024-01-24 [1] CRAN (R 4.4.0)
#>  tidyselect     1.2.1      2024-03-11 [1] CRAN (R 4.4.0)
#>  timechange     0.3.0      2024-01-18 [1] CRAN (R 4.4.0)
#>  timeDate       4032.109   2023-12-14 [1] CRAN (R 4.4.0)
#>  torch          0.13.0     2024-05-21 [1] CRAN (R 4.4.0)
#>  tune         * 1.2.1      2024-04-18 [1] CRAN (R 4.4.0)
#>  utf8           1.2.4      2023-10-22 [1] CRAN (R 4.4.0)
#>  vctrs          0.6.5      2023-12-01 [1] CRAN (R 4.4.0)
#>  withr          3.0.1      2024-07-31 [1] CRAN (R 4.4.0)
#>  workflows    * 1.1.4      2024-02-19 [1] CRAN (R 4.4.0)
#>  workflowsets * 1.1.0      2024-03-21 [1] CRAN (R 4.4.0)
#>  xfun           0.46       2024-07-18 [1] CRAN (R 4.4.0)
#>  yaml           2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
#>  yardstick    * 1.3.1      2024-03-21 [1] CRAN (R 4.4.0)
#>  zeallot        0.1.0      2018-01-28 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Users/carlgoodwin/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

@cgoo4 cgoo4 reopened this Aug 9, 2024
@cregouby
Collaborator

Hello @cgoo4,

I added the message on purpose, but it may be misleading. The meaning is 'the tabnet model will be fit without using the case_weights variable', as this is how tabnet actually treats the case_weights variable: it is set aside for later use by other downstream tidymodels packages.

Any proposal for a more informative message?
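
One kind of downstream use already appears in the earlier reprex, where `pr_curve()` receives `case_weights = case_wts`. As a rough illustration of what a weighted metric does with such a column, here is a minimal base R sketch. The truth, prediction and weight vectors are invented for the example, and the calculation is a plain weighted accuracy, not the code of any tidymodels package.

```r
# Made-up observed classes, predicted classes and case weights
truth    <- c("bad", "bad", "good", "good", "good", "good")
estimate <- c("bad", "good", "good", "good", "good", "bad")
wts      <- c(2, 2, 1, 1, 1, 1)

# TRUE where the prediction matches the observed class
correct <- truth == estimate

# Unweighted accuracy: every row counts the same
unweighted_acc <- mean(correct)

# Weighted accuracy: each row counts in proportion to its case weight
weighted_acc <- sum(wts * correct) / sum(wts)

print(unweighted_acc)
print(weighted_acc)
```

Errors on the up-weighted "bad" rows cost more, so the weighted figure is lower here than the unweighted one.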

@cgoo4
Author

cgoo4 commented Aug 13, 2024

Hi @cregouby - Thank you for clarifying.

If it's possible to set case_weights in more than one way, e.g. in a tidymodels workflow() and also in tabnet_fit(), then maybe the message could say that the latter is being overridden by the former?
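
The suggestion above could be sketched roughly as follows. `weights_message()` is a hypothetical helper written for this illustration; it is not part of tabnet, and the wording and the override rule are assumptions taken from the suggestion rather than confirmed package behavior.

```r
# Hypothetical helper (not a tabnet function): choose a message that says
# which source of weights applies when one or both are supplied.
weights_message <- function(has_workflow_weights, has_fit_weights) {
  if (has_workflow_weights && has_fit_weights) {
    "`weights` set in `tabnet_fit()` are overridden by the workflow case weights."
  } else if (has_fit_weights) {
    "Using `weights` set in `tabnet_fit()`."
  } else if (has_workflow_weights) {
    "Using case weights from the workflow."
  } else {
    NULL  # nothing to report when no weights are supplied
  }
}

# Both sources supplied: the message names which one wins
msg <- weights_message(has_workflow_weights = TRUE, has_fit_weights = TRUE)
if (!is.null(msg)) message(msg)
```

Naming both sources in the message would let a reader tell at a glance whether the workflow weights were picked up.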
