Skip to content

Commit

Permalink
refactor compute_grid_info() (#957)
Browse files Browse the repository at this point in the history
  • Loading branch information
simonpcouch authored Nov 1, 2024
1 parent 12faa6d commit abe9fa6
Showing 1 changed file with 44 additions and 251 deletions.
295 changes: 44 additions & 251 deletions R/grid_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -316,19 +316,52 @@ compute_grid_info <- function(workflow, grid) {
any_parameters_model <- nrow(parameters_model) > 0
any_parameters_preprocessor <- nrow(parameters_preprocessor) > 0

if (any_parameters_model) {
if (any_parameters_preprocessor) {
compute_grid_info_model_and_preprocessor(workflow, grid, parameters_model)
} else {
compute_grid_info_model(workflow, grid, parameters_model)
}
res <- min_grid(extract_spec_parsnip(workflow), grid)

if (any_parameters_preprocessor) {
res$.iter_preprocessor <- seq_len(nrow(res))
} else {
if (any_parameters_preprocessor) {
compute_grid_info_preprocessor(workflow, grid, parameters_model)
} else {
rlang::abort("Internal error: `workflow` should have some tunable parameters if `grid` is not `NULL`.")
}
res$.iter_preprocessor <- 1L
}

res$.msg_preprocessor <-
new_msgs_preprocessor(
seq_len(max(res$.iter_preprocessor)),
max(res$.iter_preprocessor)
)

if (nrow(res) != nrow(grid) ||
(any_parameters_model && !any_parameters_preprocessor)) {
res$.iter_model <- seq_len(dplyr::n_distinct(res[parameters_model$id]))
} else {
res$.iter_model <- 1L
}

res$.iter_config <- list(list())
for (row in seq_len(nrow(res))) {
res$.iter_config[row] <- list(iter_config(res[row, ]))
}

res$.msg_model <-
new_msgs_model(i = res$.iter_model, n = max(res$.iter_model), res$.msg_preprocessor)

res
}

iter_config <- function(res_row) {
submodels <- res_row$.submodels[[1]]
if (identical(submodels, list())) {
models <- res_row$.iter_model
} else {
models <- seq_len(length(submodels[[1]]) + 1)
}

paste0(
"Preprocessor",
res_row$.iter_preprocessor,
"_Model",
format_with_padding(models)
)
}

# This generates a "dummy" grid_info object that has the same
Expand Down Expand Up @@ -360,217 +393,6 @@ new_grid_info_resamples <- function() {
out
}

compute_grid_info_preprocessor <- function(workflow,
grid,
parameters_model) {
out <- grid

n_preprocessors <- nrow(out)
seq_preprocessors <- seq_len(n_preprocessors)

# Preprocessor<i>_Model1
ids <- format_with_padding(seq_preprocessors)
iter_configs <- paste0("Preprocessor", ids, "_Model1")
iter_configs <- as.list(iter_configs)

# preprocessor <i>/<n>
msgs_preprocessor <- new_msgs_preprocessor(
i = seq_preprocessors,
n = n_preprocessors
)

# preprocessor <i>/<n>, model 1/1
msgs_model <- new_msgs_model(
i = 1L,
n = 1L,
msgs_preprocessor = msgs_preprocessor
)

# Manually add .submodels column, which will always have empty lists
submodels <- rep_len(list(list()), n_preprocessors)

out <- tibble::add_column(
.data = out,
.iter_preprocessor = seq_preprocessors,
.before = 1L
)

out <- tibble::add_column(
.data = out,
.msg_preprocessor = msgs_preprocessor,
.after = ".iter_preprocessor"
)

# Add at the end
out <- tibble::add_column(
.data = out,
.iter_model = 1L,
.after = NULL
)

out <- tibble::add_column(
.data = out,
.iter_config = iter_configs,
.after = ".iter_model"
)

out <- tibble::add_column(
.data = out,
.msg_model = msgs_model,
.after = ".iter_config"
)

out <- tibble::add_column(
.data = out,
.submodels = submodels,
.after = ".msg_model"
)

out
}

compute_grid_info_model <- function(workflow,
grid,
parameters_model) {
spec <- extract_spec_parsnip(workflow)
out <- min_grid(spec, grid)

n_fit_models <- nrow(out)
seq_fit_models <- seq_len(n_fit_models)

# preprocessor 1/1
msgs_preprocessor <- new_msgs_preprocessor(i = 1L, n = 1L)
msgs_preprocessor <- rep(msgs_preprocessor, times = n_fit_models)

# preprocessor 1/1, model <i_fit>/<n_fit>
msgs_model <- new_msgs_model(
i = seq_fit_models,
n = n_fit_models,
msgs_preprocessor = msgs_preprocessor
)

# Preprocessor1_Model<i>
iter_configs <- compute_config_ids(out, "Preprocessor1")

out <- tibble::add_column(
.data = out,
.iter_preprocessor = 1L,
.before = 1L
)

out <- tibble::add_column(
.data = out,
.msg_preprocessor = msgs_preprocessor,
.after = ".iter_preprocessor"
)

out <- tibble::add_column(
.data = out,
.iter_model = seq_fit_models,
.after = ".msg_preprocessor"
)

out <- tibble::add_column(
.data = out,
.iter_config = iter_configs,
.after = ".iter_model"
)

out <- tibble::add_column(
.data = out,
.msg_model = msgs_model,
.after = ".iter_config"
)

out
}

compute_grid_info_model_and_preprocessor <- function(workflow,
grid,
parameters_model) {
parameter_names_model <- parameters_model[["id"]]

# Nest model parameters, keep preprocessor parameters outside
out <- tidyr::nest(grid, data = dplyr::all_of(parameter_names_model))

n_preprocessors <- nrow(out)
seq_preprocessors <- seq_len(n_preprocessors)

# preprocessor <i_pre>/<n_pre>
msgs_preprocessor <- new_msgs_preprocessor(
i = seq_preprocessors,
n = n_preprocessors
)

out <- tibble::add_column(
.data = out,
.iter_preprocessor = seq_preprocessors,
.before = 1L
)

out <- tibble::add_column(
.data = out,
.msg_preprocessor = msgs_preprocessor,
.after = ".iter_preprocessor"
)

spec <- extract_spec_parsnip(workflow)

ids_preprocessor <- format_with_padding(seq_preprocessors)
ids_preprocessor <- paste0("Preprocessor", ids_preprocessor)

model_grids <- out[["data"]]

for (i in seq_preprocessors) {
model_grid <- model_grids[[i]]

model_grid <- min_grid(spec, model_grid)

n_fit_models <- nrow(model_grid)
seq_fit_models <- seq_len(n_fit_models)

msg_preprocessor <- msgs_preprocessor[[i]]
id_preprocessor <- ids_preprocessor[[i]]

# preprocessor <i_pre>/<n_pre>, model <i_mod>/<n_mod>
msgs_model <- new_msgs_model(
i = seq_fit_models,
n = n_fit_models,
msgs_preprocessor = msg_preprocessor
)

# Preprocessor<i_pre>_Model<i>
iter_configs <- compute_config_ids(model_grid, id_preprocessor)

model_grid <- tibble::add_column(
.data = model_grid,
.iter_model = seq_fit_models,
.before = 1L
)

model_grid <- tibble::add_column(
.data = model_grid,
.iter_config = iter_configs,
.after = ".iter_model"
)

model_grid <- tibble::add_column(
.data = model_grid,
.msg_model = msgs_model,
.after = ".iter_config"
)

model_grids[[i]] <- model_grid
}

out[["data"]] <- model_grids

# Unnest to match other grid-info generators
out <- tidyr::unnest(out, data)

out
}

new_msgs_preprocessor <- function(i, n) {
paste0("preprocessor ", i, "/", n)
}
Expand All @@ -583,35 +405,6 @@ format_with_padding <- function(x) {
gsub(" ", "0", format(x))
}

compute_config_ids <- function(data, id_preprocessor) {
submodels <- unnest(data, .submodels, keep_empty = TRUE)
submodels <- pull(submodels, .submodels)

# Current model that actually is fit is not included in the submodel count
# so we add 1
model_sizes <- lengths(submodels) + 1L

n_total_models <- sum(model_sizes)

ids <- format_with_padding(seq_len(n_total_models))
ids <- paste0(id_preprocessor, "_Model", ids)

n_fit_models <- nrow(data)

out <- vector("list", length = n_fit_models)

start <- 1L

for (i in seq_len(n_fit_models)) {
size <- model_sizes[[i]]
stop <- start + size - 1L
out[[i]] <- ids[rlang::seq2(start, stop)]
start <- stop + 1L
}

out
}

# ------------------------------------------------------------------------------

has_preprocessor <- function(workflow) {
Expand Down

0 comments on commit abe9fa6

Please sign in to comment.