Predictions in tune_grid for poisson regression models appear to be on the link scale (negative values) #956

Open · josho88 opened this issue Oct 31, 2024 · 5 comments
Labels: bug (an unexpected problem or unintended behavior)

josho88 commented Oct 31, 2024

The problem

When running tune_grid for glmnet poisson regression models, the predictions returned by collect_predictions() contain negative values; they appear to be on the link scale. This is inconsistent with the behaviour of predict when fitting and evaluating models in the usual way (i.e., outside of tuning procedures).

It would seem that these link-scale predictions are then used to compute performance metrics, by comparing them against observations on the original response scale (i.e., counts).

The behaviour also seems inconsistent across CV resamples: some folds appear to be on the link scale and others on the response scale (this isn't covered in the example below).

Reproducible example

library(tidyverse)
library(tidymodels)
library(glmnet)
#> Loading required package: Matrix
#> 
#> Attaching package: 'Matrix'
#> The following objects are masked from 'package:tidyr':
#> 
#>     expand, pack, unpack
#> Loaded glmnet 4.1-8

n <- 500

set.seed(123)
dat <- tibble(
  x1 = rnorm(n),
  x2 = rnorm(n),
  x3 = rnorm(n),
  x4 = rnorm(n),
  x5 = rnorm(n),
  y = rpois(n, 1)
)
wf <- workflow()

# split test and train ---------------------------------------------------------
set.seed(1839)
dat_split <- dat %>%
  initial_split()

dat_train <- training(dat_split)
dat_test <- testing(dat_split)

# preprocessing ----------------------------------------------------------------
dat_train_prep <- recipe(y ~ ., data = dat_train) %>%
  step_normalize(all_numeric(), -y)

wf <- wf %>%
  add_recipe(dat_train_prep)

# specify model ----------------------------------------------------------------
lasso_spec <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet", family = poisson)

# doesn't seem to make a difference using poisson_reg
# lasso_spec <- parsnip::poisson_reg(penalty = tune(), mixture = 1) %>%
#   set_engine("glmnet")

wf <- wf %>%
  add_model(lasso_spec)

rs <- vfold_cv(dat_train, v = 5)

# tune for lambda --------------------------------------------------------------
lasso_grid <- tune_grid(
  wf,
  resamples = rs,
  control =
    control_grid(
      save_pred = TRUE,
      extract = extract_model
    ),
  grid = grid_regular(penalty(), levels = 50)
)
#> → A | warning: `extract_model()` was deprecated in tune 0.1.6.
#>                ℹ Please use `extract_fit_engine()` instead.
#>                ℹ The deprecated feature was likely used in the tune package.
#>                  Please report the issue at <https://github.com/tidymodels/tune/issues>.
#> There were issues with some computations   A: x1
#> → B | warning: A correlation computation is required, but `estimate` is constant and has 0 standard deviation, resulting in a divide by 0 error. `NA` will be returned.
#> There were issues with some computations   A: x1   B: x5

preds <- lasso_grid %>% collect_predictions() # get negative values here
preds
#> # A tibble: 18,750 × 6
#>    id      .pred  .row      penalty     y .config              
#>    <chr>   <dbl> <int>        <dbl> <int> <chr>                
#>  1 Fold1 -0.0136    26 0.0000000001     3 Preprocessor1_Model01
#>  2 Fold1 -0.279     32 0.0000000001     2 Preprocessor1_Model01
#>  3 Fold1 -0.128     37 0.0000000001     1 Preprocessor1_Model01
#>  4 Fold1 -0.0575    43 0.0000000001     1 Preprocessor1_Model01
#>  5 Fold1 -0.0824    44 0.0000000001     0 Preprocessor1_Model01
#>  6 Fold1 -0.244     45 0.0000000001     1 Preprocessor1_Model01
#>  7 Fold1 -0.347     48 0.0000000001     0 Preprocessor1_Model01
#>  8 Fold1 -0.0390    50 0.0000000001     1 Preprocessor1_Model01
#>  9 Fold1 -0.0264    55 0.0000000001     2 Preprocessor1_Model01
#> 10 Fold1 -0.0440    63 0.0000000001     0 Preprocessor1_Model01
#> # ℹ 18,740 more rows

Created on 2024-11-01 with reprex v2.1.1

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.1.3 (2022-03-10)
#>  os       macOS 13.4.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Europe/London
#>  date     2024-11-01
#>  pandoc   3.1.1 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package      * version    date (UTC) lib source
#>  assertthat     0.2.1      2019-03-21 [1] CRAN (R 4.1.0)
#>  backports      1.4.1      2021-12-13 [1] CRAN (R 4.1.1)
#>  broom        * 1.0.5      2023-06-09 [1] CRAN (R 4.1.3)
#>  cellranger     1.1.0      2016-07-27 [1] CRAN (R 4.1.0)
#>  class          7.3-20     2022-01-16 [1] CRAN (R 4.1.3)
#>  cli            3.6.1      2023-03-23 [1] CRAN (R 4.1.3)
#>  codetools      0.2-18     2020-11-04 [1] CRAN (R 4.1.3)
#>  colorspace     2.0-3      2022-02-21 [1] CRAN (R 4.1.1)
#>  crayon         1.5.1      2022-03-26 [1] CRAN (R 4.1.1)
#>  DBI            1.1.2      2021-12-20 [1] CRAN (R 4.1.1)
#>  dbplyr         2.1.1      2021-04-06 [1] CRAN (R 4.1.0)
#>  dials        * 1.2.0      2023-04-03 [1] CRAN (R 4.1.3)
#>  DiceDesign     1.9        2021-02-13 [1] CRAN (R 4.1.0)
#>  digest         0.6.29     2021-12-01 [1] CRAN (R 4.1.1)
#>  dplyr        * 1.1.2      2023-04-20 [1] CRAN (R 4.1.3)
#>  ellipsis       0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
#>  evaluate       0.15       2022-02-18 [1] CRAN (R 4.1.1)
#>  fansi          1.0.3      2022-03-24 [1] CRAN (R 4.1.1)
#>  fastmap        1.1.1      2023-02-24 [1] CRAN (R 4.1.3)
#>  forcats      * 0.5.1      2021-01-27 [1] CRAN (R 4.1.1)
#>  foreach        1.5.2      2022-02-02 [1] CRAN (R 4.1.1)
#>  fs             1.5.2      2021-12-08 [1] CRAN (R 4.1.1)
#>  furrr          0.2.3      2021-06-25 [1] CRAN (R 4.1.0)
#>  future         1.24.0     2022-02-19 [1] CRAN (R 4.1.1)
#>  future.apply   1.8.1      2021-08-10 [1] CRAN (R 4.1.1)
#>  generics       0.1.3      2022-07-05 [1] CRAN (R 4.1.3)
#>  ggplot2      * 3.4.3      2023-08-14 [1] CRAN (R 4.1.3)
#>  glmnet       * 4.1-8      2023-08-22 [1] CRAN (R 4.1.3)
#>  globals        0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
#>  glue           1.6.2      2022-02-24 [1] CRAN (R 4.1.1)
#>  gower          1.0.0      2022-02-03 [1] CRAN (R 4.1.1)
#>  GPfit          1.0-8      2019-02-08 [1] CRAN (R 4.1.0)
#>  gtable         0.3.0      2019-03-25 [1] CRAN (R 4.1.1)
#>  hardhat        1.3.0      2023-03-30 [1] CRAN (R 4.1.3)
#>  haven          2.4.3      2021-08-04 [1] CRAN (R 4.1.1)
#>  hms            1.1.3      2023-03-21 [1] CRAN (R 4.1.3)
#>  htmltools      0.5.8.1    2024-04-04 [1] CRAN (R 4.1.3)
#>  httr           1.4.2      2020-07-20 [1] CRAN (R 4.1.0)
#>  infer        * 1.0.4      2022-12-02 [1] CRAN (R 4.1.3)
#>  ipred          0.9-12     2021-09-15 [1] CRAN (R 4.1.1)
#>  iterators      1.0.14     2022-02-05 [1] CRAN (R 4.1.1)
#>  jsonlite       1.8.0      2022-02-22 [1] CRAN (R 4.1.1)
#>  knitr          1.48       2024-07-07 [1] CRAN (R 4.1.3)
#>  lattice        0.20-45    2021-09-22 [1] CRAN (R 4.1.3)
#>  lava           1.6.10     2021-09-02 [1] CRAN (R 4.1.1)
#>  lhs            1.1.5      2022-03-22 [1] CRAN (R 4.1.1)
#>  lifecycle      1.0.3      2022-10-07 [1] CRAN (R 4.1.3)
#>  listenv        0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
#>  lubridate      1.9.3      2023-09-27 [1] CRAN (R 4.1.3)
#>  magrittr       2.0.3      2022-03-30 [1] CRAN (R 4.1.1)
#>  MASS           7.3-55     2022-01-16 [1] CRAN (R 4.1.3)
#>  Matrix       * 1.4-0      2021-12-08 [1] CRAN (R 4.1.3)
#>  modeldata    * 1.2.0      2023-08-09 [1] CRAN (R 4.1.3)
#>  modelr         0.1.8      2020-05-19 [1] CRAN (R 4.1.0)
#>  munsell        0.5.0      2018-06-12 [1] CRAN (R 4.1.0)
#>  nnet           7.3-17     2022-01-16 [1] CRAN (R 4.1.3)
#>  parallelly     1.31.0     2022-04-07 [1] CRAN (R 4.1.1)
#>  parsnip      * 1.1.1      2023-08-17 [1] CRAN (R 4.1.3)
#>  pillar         1.9.0      2023-03-22 [1] CRAN (R 4.1.3)
#>  pkgconfig      2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
#>  prodlim        2019.11.13 2019-11-17 [1] CRAN (R 4.1.0)
#>  purrr        * 1.0.2      2023-08-10 [1] CRAN (R 4.1.3)
#>  R.cache        0.16.0     2022-07-21 [1] CRAN (R 4.1.1)
#>  R.methodsS3    1.8.2      2022-06-13 [1] CRAN (R 4.1.1)
#>  R.oo           1.25.0     2022-06-12 [1] CRAN (R 4.1.1)
#>  R.utils        2.12.2     2022-11-11 [1] CRAN (R 4.1.1)
#>  R6             2.5.1      2021-08-19 [1] CRAN (R 4.1.1)
#>  Rcpp           1.0.10     2023-01-22 [1] CRAN (R 4.1.1)
#>  readr        * 2.1.2      2022-01-30 [1] CRAN (R 4.1.1)
#>  readxl         1.4.0      2022-03-28 [1] CRAN (R 4.1.1)
#>  recipes      * 1.0.8      2023-08-25 [1] CRAN (R 4.1.3)
#>  reprex         2.1.1      2024-07-06 [1] CRAN (R 4.1.3)
#>  rlang          1.1.1      2023-04-28 [1] CRAN (R 4.1.3)
#>  rmarkdown      2.28       2024-08-17 [1] CRAN (R 4.1.3)
#>  rpart          4.1.16     2022-01-24 [1] CRAN (R 4.1.3)
#>  rsample      * 1.2.0      2023-08-23 [1] CRAN (R 4.1.3)
#>  rstudioapi     0.15.0     2023-07-07 [1] CRAN (R 4.1.3)
#>  rvest          1.0.2      2021-10-16 [1] CRAN (R 4.1.1)
#>  scales       * 1.2.1      2022-08-20 [1] CRAN (R 4.1.3)
#>  sessioninfo    1.2.2      2021-12-06 [1] CRAN (R 4.1.1)
#>  shape          1.4.6      2021-05-19 [1] CRAN (R 4.1.0)
#>  stringi        1.7.6      2021-11-29 [1] CRAN (R 4.1.1)
#>  stringr      * 1.5.0      2022-12-02 [1] CRAN (R 4.1.3)
#>  styler         1.9.0      2023-01-15 [1] CRAN (R 4.1.1)
#>  survival       3.2-13     2021-08-24 [1] CRAN (R 4.1.3)
#>  tibble       * 3.2.1      2023-03-20 [1] CRAN (R 4.1.3)
#>  tidymodels   * 1.1.1      2023-08-24 [1] CRAN (R 4.1.3)
#>  tidyr        * 1.3.0      2023-01-24 [1] CRAN (R 4.1.3)
#>  tidyselect     1.2.0      2022-10-10 [1] CRAN (R 4.1.3)
#>  tidyverse    * 1.3.1      2021-04-15 [1] CRAN (R 4.1.0)
#>  timechange     0.3.0      2024-01-18 [1] CRAN (R 4.1.3)
#>  timeDate       3043.102   2018-02-21 [1] CRAN (R 4.1.0)
#>  tune         * 1.1.2      2023-08-23 [1] CRAN (R 4.1.3)
#>  tzdb           0.4.0      2023-05-12 [1] CRAN (R 4.1.3)
#>  utf8           1.2.2      2021-07-24 [1] CRAN (R 4.1.0)
#>  vctrs          0.6.3      2023-06-14 [1] CRAN (R 4.1.3)
#>  withr          2.5.0      2022-03-03 [1] CRAN (R 4.1.1)
#>  workflows    * 1.1.3      2023-02-22 [1] CRAN (R 4.1.3)
#>  workflowsets * 1.0.1      2023-04-06 [1] CRAN (R 4.1.3)
#>  xfun           0.44       2024-05-15 [1] CRAN (R 4.1.3)
#>  xml2           1.3.3      2021-11-30 [1] CRAN (R 4.1.1)
#>  yaml           2.3.5      2022-02-21 [1] CRAN (R 4.1.1)
#>  yardstick    * 1.2.0      2023-04-21 [1] CRAN (R 4.1.3)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.1-arm64/Resources/library
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Note: the code below demonstrates how the metrics appear to be calculated, i.e., from predictions on the link scale. Normal predict() behaviour does not seem to use type = "raw".

# extract metrics
met <- lasso_grid %>% collect_metrics(summarize = FALSE)

# extract predictions for fold 1
fold1 <- lasso_grid[1, ]

preds_fold1 <- fold1 %>%
  unnest(.predictions) %>%
  select(.pred:y)

# metrics seem to be computed by comparing predictions on link scale with observations on response scale?
preds_fold1 %>%
  filter(penalty == 1e-10) %>%
  metrics(truth = y, estimate = .pred)
#> # A tibble: 3 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard    1.57    
#> 2 rsq     standard    0.000867
#> 3 mae     standard    1.24

# fold1 metrics - match those created above
met[1:2, ]
#> # A tibble: 2 × 6
#>   id         penalty .metric .estimator .estimate .config              
#>   <chr>        <dbl> <chr>   <chr>          <dbl> <chr>                
#> 1 Fold1 0.0000000001 rmse    standard    1.57     Preprocessor1_Model01
#> 2 Fold1 0.0000000001 rsq     standard    0.000867 Preprocessor1_Model01

# we get different results if predictions are converted to response scale (as they seem to be in fit_resamples and other predict methods)
preds_fold1 %>%
  filter(penalty == 1e-10) %>%
  mutate(.pred = exp(.pred)) %>%
  metrics(truth = y, estimate = .pred)
#> # A tibble: 3 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard    0.989   
#> 2 rsq     standard    0.000374
#> 3 mae     standard    0.792

# what about when fitting a workflow manually to fold 1?
train <- rs$splits[[1]] %>% analysis()
test <- rs$splits[[1]] %>% assessment()

m <- workflows::extract_spec_parsnip(wf)
m <- update(
  m,
  penalty = 1e-10
)

# add updated model to workflow object
wf <- workflows::update_model(wf, spec = m)

# fit and generate predictions
# these do not appear to be negative
# i.e., normal behaviour seems to be to return on the response scale
wf_fit <- wf %>% fit(train)
wf_fit %>% predict(test)
#> # A tibble: 75 × 1
#>    .pred
#>    <dbl>
#>  1 0.987
#>  2 0.756
#>  3 0.879
#>  4 0.944
#>  5 0.921
#>  6 0.783
#>  7 0.707
#>  8 0.962
#>  9 0.974
#> 10 0.957
#> # ℹ 65 more rows

# setting type = "raw" generates negatives
wf_fit %>%
  predict(test, type = "raw") %>%
  head()
#>               s1
#> [1,] -0.01355497
#> [2,] -0.27939017
#> [3,] -0.12843747
#> [4,] -0.05751146
#> [5,] -0.08241273
#> [6,] -0.24400376
hfrick (Member) commented Nov 1, 2024

I think this is a result of this bug in poissonreg, since tune_grid() uses multi_predict() where possible: tidymodels/poissonreg#63
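
For illustration, a minimal, untested sketch of that code path, reusing wf_fit and test from the reprex above (tune_grid() predicts via the underlying parsnip fit rather than the workflow):

fit_parsnip <- workflows::extract_fit_parsnip(wf_fit)
test_baked <- workflows::extract_recipe(wf_fit) %>%
  bake(new_data = test)

multi_predict(
  fit_parsnip,
  new_data = test_baked %>% select(-y),
  penalty = 1e-10
) %>%
  unnest(.pred) %>%
  head()
# .pred values come back negative here, i.e., on the link scale,
# matching collect_predictions() above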

josho88 (Author) commented Nov 3, 2024

Ah OK, yes, that makes sense. I presume that bug also affects linear_reg() %>% set_engine("glmnet", family = poisson) then?

hfrick (Member) commented Nov 13, 2024

glmnet allows you to use a family from base R, like you did in linear_reg() %>% set_engine("glmnet", family = poisson), but it is usually faster when using glmnet's own implementation. You can choose that by setting the family argument to "poisson", i.e., just the string instead of the poisson family object. If you set family to the string, glmnet will return an object classed as fishnet to indicate that it's a poisson regression at heart. In that case, the poissonreg package knows what to do with it. I've just merged a fix for that bug into the dev version.

So to get correct results, you'd need to (see the sketch after this list):

  • install the dev version of poissonreg via pak::pak("tidymodels/poissonreg")
  • load poissonreg at the top of your script via library(poissonreg) to get the correct multi_predict() method for fishnet models
  • set the family to "poisson", i.e. to the string, not the base-R family object
    (Being a bit on the verbose side here, in case someone else runs into this.)
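
Putting those steps together, a minimal sketch of the corrected specification (assuming the dev version of poissonreg is installed):

library(poissonreg)  # provides the multi_predict() method for fishnet models

lasso_spec_fixed <- linear_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet", family = "poisson")  # the string, not the base-R family object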

If you use a base-R family, the resulting object is one with a "catch-all" class of glmnetfit, and parsnip can't infer from that class that you'd expect it to serve up the counts rather than the linear predictor in this case.
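
A quick sketch of that class difference, using the simulated dat from the reprex:

x <- as.matrix(dat[, c("x1", "x2", "x3", "x4", "x5")])

# string family -> glmnet's own poisson implementation
class(glmnet::glmnet(x, dat$y, family = "poisson"))
#> [1] "fishnet" "glmnet"

# base-R family object -> "catch-all" class
class(glmnet::glmnet(x, dat$y, family = poisson()))
#> [1] "glmnetfit" "glmnet"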

hfrick (Member) commented Nov 13, 2024

Ah, and if you use poisson_reg() %>% set_engine("glmnet") directly, it will set family = "poisson" for you. That's probably the better way, since it makes it more obvious that you need the poissonreg package.
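
In code, a sketch of that equivalent spec:

library(poissonreg)

lasso_spec_pois <- poisson_reg(penalty = tune(), mixture = 1) %>%
  set_engine("glmnet")  # sets family = "poisson" internally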

topepo (Member) commented Nov 13, 2024

We might issue a warning when family is set to "poisson" in linear_reg().

topepo added the bug label on Nov 13, 2024