From 1c9da3cfe5d99764195bdd04a83337c092e05c1e Mon Sep 17 00:00:00 2001 From: olivroy Date: Mon, 12 Aug 2024 12:18:49 -0400 Subject: [PATCH] Support more than 1 group in `summarise_with_total()` --- NEWS.md | 2 ++ R/dplyr-plus.R | 50 +++++++++++++++++--------------- tests/testthat/test-dplyr-plus.R | 18 ++++++++++++ 3 files changed, 46 insertions(+), 24 deletions(-) diff --git a/NEWS.md b/NEWS.md index d272ded..6890025 100644 --- a/NEWS.md +++ b/NEWS.md @@ -50,6 +50,8 @@ that will passed on to `proj_list()` * `active_rs_doc()` returns the relative path if in RStudio project. +* `summarise_with_total()` now supports more than one grouping variable when computing the total summary. + # reuseme 0.0.2 * `complete_todo()` no longer deletes the full line. It only deletes what it says it deletes (#27). diff --git a/R/dplyr-plus.R b/R/dplyr-plus.R index 51bd667..47bb692 100644 --- a/R/dplyr-plus.R +++ b/R/dplyr-plus.R @@ -362,7 +362,7 @@ extract_cell_value <- function(data, var, filter, name = NULL, length = NULL, un } # summarise with total --------------------------------------------------------- -#' Compute a summary for one group with the total included. +#' Compute a summary for groups with the total included. #' #' This function is useful to create end tables, apply the same formula to a group and to its overall. #' You can specify a personalized `Total` value with the `.label` argument. You @@ -397,48 +397,50 @@ extract_cell_value <- function(data, var, filter, name = NULL, length = NULL, un #' ) summarise_with_total <- function(.data, ..., .by = NULL, .label = "Total", .first = TRUE) { check_string(.label) + # check_dots_used() # Computing summary (depending if .data is grouped or uses `.by`) if (dplyr::is_grouped_df(.data)) { group_var <- dplyr::group_vars(.data) - if (length(group_var) != 1) { - cli::cli_abort(c( - "Must supply a single group" - )) - } - by_summary <- dplyr::summarise(.data, ...) 
summary <- dplyr::summarise( - .data = dplyr::ungroup(.data), - "{group_var}" := .label, + dplyr::ungroup(.data), + dplyr::across( + dplyr::all_of(group_var), + function(x) .label + ), ... ) } else { # compute summary by variable by_summary <- dplyr::summarise(.data, ..., .by = {{ .by }}) - # Compute the summary for total - summary <- dplyr::summarise(.data, "{{ .by }}" := .label, ...) + summary <- dplyr::summarise(.data, dplyr::across({{ .by }}, function(x) .label), ...) } - # Decide how to arrange the data. - summary_levels <- if (.first) { - c(.label, as.character(levels(by_summary[[1]]) %||% unique(by_summary[[1]]))) - } else { - c(as.character(levels(by_summary[[1]]) %||% unique(by_summary[[1]])), .label) - } + # Figure out which columns are the total column. + total_cols <- which(purrr::map_lgl(summary, function(x) x == .label)) - if (is.factor(by_summary[[1]])) { - by_summary[[1]] <- factor(by_summary[[1]], levels = summary_levels) - summary[[1]] <- factor(summary[[1]], levels = summary_levels) - } else if (!is.character(by_summary[[1]])) { - by_summary[[1]] <- factor(by_summary[[1]], levels = summary_levels) - summary[[1]] <- factor(summary[[1]], levels = summary_levels) - } + for (i in seq_along(total_cols)) { + # Decide how to arrange the data. 
+ col_id <- names(total_cols)[[i]] + summary_levels <- if (.first) { + c(.label, as.character(levels(.data[[col_id]]) %||% unique(.data[[col_id]]))) + } else { + c(as.character(levels(.data[[col_id]]) %||% unique(.data[[col_id]])), .label) + } + if (is.factor(by_summary[[col_id]])) { + by_summary[[col_id]] <- factor(by_summary[[col_id]], levels = summary_levels) + summary[[col_id]] <- factor(summary[[col_id]], levels = summary_levels) + } else if (!is.character(by_summary[[col_id]])) { + by_summary[[col_id]] <- factor(by_summary[[col_id]], levels = summary_levels) + summary[[col_id]] <- factor(summary[[col_id]], levels = summary_levels) + } + } # .first decides which ones to bind if (.first) { diff --git a/tests/testthat/test-dplyr-plus.R b/tests/testthat/test-dplyr-plus.R index 06206a1..12f87aa 100644 --- a/tests/testthat/test-dplyr-plus.R +++ b/tests/testthat/test-dplyr-plus.R @@ -125,6 +125,24 @@ test_that("summarise_with_total() works", { }) }) +test_that("summarise_with_total() works with two groups", { + gr_s <- summarise_with_total(dplyr::group_by(mtcars, vs, cyl), mpg = sum(mpg)) + by_s <- summarise_with_total(mtcars, mpg = sum(mpg), .by = c(vs, cyl)) + + expect_equal( + dim(gr_s), + dim(by_s) + ) + expect_setequal( + levels(by_s$cyl), + levels(gr_s$cyl) + ) + expect_setequal( + levels(by_s$vs), + levels(gr_s$vs) + ) +}) + test_that("summarise_with_total() keeps factors", { fac <- mtcars |> dplyr::mutate(vs = factor(vs), mpg, .keep = "none")