Skip to content

Commit

Permalink
add_chunk() uses VALUES statement to compute the number of rows in ea…
Browse files Browse the repository at this point in the history
…ch chunk
  • Loading branch information
jarodmeng committed Jan 13, 2023
1 parent 786b86a commit 47eb171
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 40 deletions.
41 changes: 25 additions & 16 deletions R/chunk.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,26 @@
#' size limit on any discrete INSERT INTO statement.
#'
#' @param value The original data frame.
#' @param chunk_size Maximum size (in bytes) of each unique chunk. Default to
#' 750,000 bytes.
#' @param chunk_fields A character vector of existing field names that are used
#' to split the data frame.
#' @param base_chunk_fields A character vector of existing field names that are
#' used to split the data frame before checking the chunk size.
#' @param chunk_size Maximum size (in bytes) of the VALUES statement encoding
#' each unique chunk. Default to 1,000,000 bytes (i.e. 1Mb).
#' @param new_chunk_field_name A string indicating the new chunk field name.
#' Default to "chunk".
#' Default to "aux_chunk_idx".
#' @importFrom rlang :=
#' @export
#' @examples
#' \dontrun{
#' # returns the original data frame because it's within size
#' add_chunk(iris)
#' # add a new chunk_idx field
#' # add a new aux_chunk_idx field
#' add_chunk(iris, chunk_size = 2000)
#' # the new chunk_idx field is added on top of Species
#' add_chunk(iris, chunk_size = 2000, chunk_fields = c("Species"))
#' # the new aux_chunk_idx field is added on top of Species
#' add_chunk(iris, chunk_size = 2000, base_chunk_fields = c("Species"))
#' }
add_chunk <- function(
value, chunk_size = 7.5e5,
chunk_fields = NULL, new_chunk_field_name = "chunk_idx"
value, base_chunk_fields = NULL, chunk_size = 1e6,
new_chunk_field_name = "aux_chunk_idx"
) {
.add_chunk <- function(value, start = 1L) {
if (new_chunk_field_name %in% colnames(value)) {
Expand All @@ -41,17 +41,23 @@ add_chunk <- function(
call. = FALSE
)
}
n_chunks <- (as.integer(utils::object.size(value)) %/% chunk_size) + 1
chunk_size <- nrow(value) %/% n_chunks
sample_value <- dplyr::slice(
value, sample(1:nrow(value), 100, replace = TRUE)
)
sample_value_query_size <- utils::object.size(
.create_values_statement(dummyPrestoConnection(), sample_value)
)
avg_row_query_size = as.integer(sample_value_query_size)/100
n_rows_per_chunk <- chunk_size %/% avg_row_query_size
dplyr::mutate(
dplyr::ungroup(value),
!!rlang::sym(new_chunk_field_name) :=
start + as.integer((dplyr::row_number() - 1L) %/% chunk_size)
start + as.integer((dplyr::row_number() - 1L) %/% n_rows_per_chunk)
)
}

if (!is.null(chunk_fields)) {
split_values <- dplyr::group_split(value, !!!rlang::syms(chunk_fields))
if (!is.null(base_chunk_fields)) {
split_values <- dplyr::group_split(value, !!!rlang::syms(base_chunk_fields))
start <- 0L
res <- vector(mode = "list", length = length(split_values))
for (i in seq_along(res)) {
Expand All @@ -65,7 +71,10 @@ add_chunk <- function(
return(dplyr::bind_rows(res))
}
} else {
if (utils::object.size(value) <= chunk_size) {
value_query_size <- utils::object.size(
.create_values_statement(dummyPrestoConnection(), value)
)
if (value_query_size <= chunk_size) {
return(value)
} else {
return(.add_chunk(value))
Expand Down
13 changes: 8 additions & 5 deletions R/dbWriteTable.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,10 @@ NULL
{
if (!found || overwrite) {
if (use.one.query) {
sql_values <- DBI::sqlData(conn, value)
fields <- DBI::dbQuoteIdentifier(conn, names(sql_values))
rows <- do.call(paste, c(unname(sql_values), sep = ", "))
fields <- DBI::dbQuoteIdentifier(conn, colnames(value))
sql <- DBI::SQL(paste0(
"SELECT * FROM (\n",
"VALUES\n",
paste0(" (", rows, ")", collapse = ",\n"),
.create_values_statement(conn, value),
") AS t (", paste(fields, collapse = ", "), ")\n"
))
dbCreateTableAs(
Expand Down Expand Up @@ -180,3 +177,9 @@ setMethod(
signature("PrestoConnection", "ANY", "data.frame"),
.dbWriteTable
)

.create_values_statement <- function(conn, value, row.names = FALSE) {
sql_values <- DBI::sqlData(conn, value, row.names)
rows <- do.call(paste, c(unname(sql_values), sep = ", "))
DBI::SQL(paste0("VALUES\n", paste0(" (", rows, ")", collapse = ",\n")))
}
22 changes: 11 additions & 11 deletions man/add_chunk.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions tests/testthat/test-add_chunk.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ test_that("add_chunk returns the original data frame if within size", {
iris
)
expect_equal_data_frame(
add_chunk(iris, chunk_fields = c("Species")),
add_chunk(iris, base_chunk_fields = c("Species")),
iris
)
})
Expand All @@ -23,9 +23,9 @@ test_that("add_chunk adds a new field when larger than size limit", {
chunk_iris <- add_chunk(iris, chunk_size = 2000)
expect_equal(
colnames(chunk_iris),
c(colnames(iris), "chunk_idx")
c(colnames(iris), "aux_chunk_idx")
)
expect_equal(class(chunk_iris$chunk_idx), "integer")
expect_equal(class(chunk_iris$aux_chunk_idx), "integer")
chunk_iris_2 <-
add_chunk(iris, chunk_size = 2000, new_chunk_field_name = "chunk")
expect_equal(
Expand All @@ -34,14 +34,14 @@ test_that("add_chunk adds a new field when larger than size limit", {
)
expect_equal(class(chunk_iris_2$chunk), "integer")
chunk_iris_field <-
add_chunk(iris, chunk_size = 2000, chunk_fields = c("Species"))
add_chunk(iris, chunk_size = 2000, base_chunk_fields = c("Species"))
expect_equal(
colnames(chunk_iris_field),
c(colnames(iris), "chunk_idx")
c(colnames(iris), "aux_chunk_idx")
)
expect_equal(class(chunk_iris_field$chunk_idx), "integer")
expect_equal(class(chunk_iris_field$aux_chunk_idx), "integer")
expect_equal(
nrow(dplyr::count(chunk_iris_field, Species, chunk_idx)),
6L
nrow(dplyr::count(chunk_iris_field, Species, aux_chunk_idx)),
9L
)
})

0 comments on commit 47eb171

Please sign in to comment.