Skip to content

Commit

Permalink
address compute_grid_info() bug for partially regular grids
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch committed Nov 4, 2024
1 parent abe9fa6 commit 24ada80
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 10 deletions.
31 changes: 21 additions & 10 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -319,27 +319,39 @@ compute_grid_info <- function(workflow, grid) {
res <- min_grid(extract_spec_parsnip(workflow), grid)

if (any_parameters_preprocessor) {
res$.iter_preprocessor <- seq_len(nrow(res))
res$.iter_preprocessor <-
vctrs::vec_group_id(res[parameters_preprocessor$id])
attr(res$.iter_preprocessor, "n") <- NULL
} else {
res$.iter_preprocessor <- 1L
}

res$.msg_preprocessor <-
new_msgs_preprocessor(
seq_len(max(res$.iter_preprocessor)),
res$.iter_preprocessor,
max(res$.iter_preprocessor)
)

if (nrow(res) != nrow(grid) ||
(any_parameters_model && !any_parameters_preprocessor)) {
res$.iter_model <- seq_len(dplyr::n_distinct(res[parameters_model$id]))
res$.iter_model <- vctrs::vec_group_id(res[parameters_model$id])
attr(res$.iter_model, "n") <- NULL
} else {
res$.iter_model <- 1L
}

res$.iter_config <- list(list())
shift_submodels <- integer(length(unique(res$.iter_preprocessor)))
for (row in seq_len(nrow(res))) {
res$.iter_config[row] <- list(iter_config(res[row, ]))
res_row <- res[row, ]
iter_config <- iter_config(
res_row,
shift = shift_submodels[res_row$.iter_preprocessor]
)
shift_submodels[res_row$.iter_preprocessor] <-
shift_submodels[res_row$.iter_preprocessor] +
length(res_row$.submodels[[1]])
res$.iter_config[row] <- list(iter_config)
}

res$.msg_model <-
Expand All @@ -348,19 +360,18 @@ compute_grid_info <- function(workflow, grid) {
res
}

#' Build the configuration labels for a single row of the grid info
#'
#' @param res_row A one-row slice of the grid-info table. The columns
#'   `.submodels`, `.iter_model` and `.iter_preprocessor` are read.
#' @param shift Integer offset added to every model index before it is
#'   formatted. Defaults to `0L` so that the one-argument call form
#'   keeps working.
#' @return A character vector of labels of the form
#'   `"Preprocessor<i>_Model<j>"`: one label when the row carries no
#'   submodels, otherwise one for the row itself plus one per submodel value.
#' @noRd
iter_config <- function(res_row, shift = 0L) {
  submodels <- res_row$.submodels[[1]]
  model_configs <- res_row$.iter_model

  if (!identical(submodels, list())) {
    # The row stands for itself plus every value of its first submodel
    # parameter, so it expands into consecutive model indices that start
    # at the row's own `.iter_model`.
    n_configs <- length(submodels[[1]]) + 1L
    model_configs <- model_configs + seq_len(n_configs) - 1L
  }

  paste0(
    "Preprocessor",
    res_row$.iter_preprocessor,
    "_Model",
    format_with_padding(shift + model_configs)
  )
}

Expand Down
112 changes: 112 additions & 0 deletions tests/testthat/test-grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,115 @@ test_that("compute_grid_info - recipe and model (with submodels)", {
)
expect_equal(nrow(res), 3)
})

test_that("compute_grid_info - recipe and model (with and without submodels)", {
  library(workflows)
  library(parsnip)
  library(recipes)
  library(dials)

  # One tunable preprocessing parameter (deg_free) and two tunable model
  # parameters (trees, loss_reduction).
  rec <- recipe(mpg ~ ., mtcars) %>% step_spline_natural(deg_free = tune())
  spec <- boost_tree(
    mode = "regression",
    trees = tune(),
    loss_reduction = tune()
  )

  wflow <- workflow()
  wflow <- add_model(wflow, spec)
  wflow <- add_recipe(wflow, rec)

  # use grid_regular to (partially) trigger submodel trick
  set.seed(1)
  param_set <- extract_parameter_set_dials(wflow)
  # Namespace-qualified: this test attaches workflows, parsnip, recipes and
  # dials only, so an unqualified bind_rows() would depend on dplyr having
  # been attached somewhere else.
  grid <- dplyr::bind_rows(
    grid_regular(param_set),
    grid_space_filling(param_set)
  )
  res <- compute_grid_info(wflow, grid)

  # Preprocessor bookkeeping.
  expect_equal(length(unique(res$.iter_preprocessor)), 5)
  expect_equal(
    unique(res$.msg_preprocessor),
    paste0("preprocessor ", 1:5, "/5")
  )

  # Model bookkeeping.
  expect_equal(res$trees, c(rep(max(grid$trees), 10), 1))
  expect_equal(res$.iter_model, c(rep(1:3, each = 3), 4, 5))
  expect_equal(
    res$.iter_config[1:3],
    list(
      c(
        "Preprocessor1_Model1", "Preprocessor1_Model2",
        "Preprocessor1_Model3", "Preprocessor1_Model4"
      ),
      c("Preprocessor2_Model1", "Preprocessor2_Model2", "Preprocessor2_Model3"),
      c("Preprocessor3_Model1", "Preprocessor3_Model2", "Preprocessor3_Model3")
    )
  )
  expect_equal(
    res$.msg_model[1:3],
    paste0("preprocessor ", 1:3, "/5, model 1/5")
  )
  expect_equal(
    res$.submodels[1:3],
    list(
      list(trees = c(1L, 1000L, 1000L)),
      list(trees = c(1L, 1000L)),
      list(trees = c(1L, 1000L))
    )
  )

  # Overall shape of the result.
  expect_named(
    res,
    c(
      ".iter_preprocessor", ".msg_preprocessor", "deg_free", "trees",
      "loss_reduction", ".iter_model", ".iter_config", ".msg_model",
      ".submodels"
    ),
    ignore.order = TRUE
  )
  expect_equal(nrow(res), 11)
})

test_that("compute_grid_info - model (with and without submodels)", {
  library(workflows)
  library(parsnip)
  library(recipes)
  library(dials)

  # Untuned recipe plus a MARS specification with three tunable arguments.
  plain_rec <- recipe(mpg ~ ., mtcars)
  mars_spec <- set_engine(
    set_mode(
      mars(num_terms = tune(), prod_degree = tune(), prune_method = tune()),
      "classification"
    ),
    "earth"
  )

  wflow <- workflow() %>%
    add_model(mars_spec) %>%
    add_recipe(plain_rec)

  # Space-filling grid requested with size = 7.
  set.seed(123)
  params_grid <- grid_space_filling(
    range_set(num_terms(), c(1L, 12L)),
    prod_degree(),
    prune_method(values = c("backward", "none", "forward")),
    size = 7,
    type = "latin_hypercube"
  )

  grid_info <- compute_grid_info(wflow, params_grid)

  # Nothing is tuned in the recipe, so a single preprocessor is expected.
  expect_equal(grid_info$.iter_preprocessor, rep(1, 5))
  expect_equal(grid_info$.msg_preprocessor, rep("preprocessor 1/1", 5))

  # Five model rows are expected.
  expect_equal(length(unique(grid_info$num_terms)), 5)
  expect_equal(grid_info$.iter_model, 1:5)

  # The first two rows are expected to carry one submodel each, so each
  # takes two configuration labels; labels stay unique across rows.
  expected_configs <- list(
    c("Preprocessor1_Model1", "Preprocessor1_Model2"),
    c("Preprocessor1_Model3", "Preprocessor1_Model4"),
    "Preprocessor1_Model5",
    "Preprocessor1_Model6",
    "Preprocessor1_Model7"
  )
  expect_equal(grid_info$.iter_config, expected_configs)

  expect_equal(
    unique(grid_info$.msg_model),
    paste0("preprocessor 1/1, model ", 1:5, "/5")
  )

  expected_submodels <- list(
    list(num_terms = c(1)),
    list(num_terms = c(3)),
    list(),
    list(),
    list()
  )
  expect_equal(grid_info$.submodels, expected_submodels)

  expected_names <- c(
    ".iter_preprocessor", ".msg_preprocessor", "num_terms", "prod_degree",
    "prune_method", ".iter_model", ".iter_config", ".msg_model", ".submodels"
  )
  expect_named(grid_info, expected_names, ignore.order = TRUE)
  expect_equal(nrow(grid_info), 5)
})

0 comments on commit 24ada80

Please sign in to comment.