Skip to content

Commit

Permalink
- refactor!: change method for renaming prediction columns (generalis…
Browse files Browse the repository at this point in the history
…e beyond BT / Euc)

And remove function`rename_prediction_cols()` since it is now redundant
  • Loading branch information
egouldo committed Sep 9, 2024
1 parent a51c526 commit 19e94e7
Showing 1 changed file with 27 additions and 45 deletions.
72 changes: 27 additions & 45 deletions R/convert_predictions.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,19 @@ convert_predictions <- function(augmented_data,
USE.NAMES = TRUE,
SIMPLIFY = "matrix"
)

vconvert_double_transformation <- Vectorize(
FUN = conversion_2,
vectorize.args = c("beta", "se"),
USE.NAMES = TRUE,
SIMPLIFY = TRUE
)

# Define input variables to conversion fns

key_var <- augmented_data %>%
dplyr::select(dplyr::contains(c("scenario", "SurveyID")))

beta_vals <- if (names(key_var) == "scenario") {
# Extracting yi estimates to supply to beta arg in vconvert
if (!pointblank::has_columns(augmented_data, "estimate")) {
Expand All @@ -67,9 +67,9 @@ convert_predictions <- function(augmented_data,
}
augmented_data$fit
}

# Apply Conversion w Helper Funs & Input Variables

converted <- if (transformation_type %nin% "double_transformation") {
vconvert(
beta = beta_vals,
Expand All @@ -79,12 +79,12 @@ convert_predictions <- function(augmented_data,
) %>%
t()
} else { # Back-transform response AND link-function

if (any(rlang::is_na(response_transformation), rlang::is_na(link_fun))) {
cli::cli_alert_warning("Missing Value for {.arg response_transformation}, returning {.val NA}")
out <- rlang::na_cpl
}

vconvert_double_transformation(
beta = beta_vals,
se = augmented_data$se.fit,
Expand All @@ -94,47 +94,29 @@ convert_predictions <- function(augmented_data,
) %>%
t()
}


names_lookup <- tibble(input_names = augmented_data %>%
colnames() %>%
discard(~ .x == colnames(key_var))) %>%
mutate(new_names =
case_match(input_names,
"fit" ~ "mean_origin",
"estimate" ~ "mean_origin",
"se.fit" ~ "se_origin",
"ci.low" ~ "lower",
"ci.hi" ~ "upper",
.default = NA)) %>%
rows_append(tibble(input_names = "sd.fit",
new_names = "sd_origin")) %>%
deframe()

# reshape conversion outputs

out <- converted %>%
dplyr::as_tibble() %>%
dplyr::mutate(across(.cols = everything(), .fns = as.double)) %>%
rename_prediction_cols(key_var, augmented_data) %>%
dplyr::bind_cols(key_var, .) %>%
dplyr::ungroup()
}

return(out)
}

#' Rename Prediction Columns
#' @description Renames the prediction columns in the output of convert_predictions
#' @param .data A tibble of out of sample predictions on the response variable scale
#' @param key_var A tibble of the key variables used in the conversion
#' @param .old_data A tibble of the original data used in the conversion
#' @return A tibble of out of sample predictions on the response variable scale with the correct column names
#' @family Analysis-level functions
#' @family Back-transformation
#' @importFrom data.table setnames
#' @importFrom purrr discard
#' @export
rename_prediction_cols <- function(.data, key_var, .old_data) {
if (names(key_var) == "scenario") {
.data %>%
data.table::setnames(
old = colnames(.data), # TODO: is this order correct? .old_data is being used for the new variable..
new = colnames(.old_data) %>% purrr::discard(~ .x == "scenario")
)
} else if (names(key_var) == "SurveyID") {
.data %>%
data.table::setnames(
old = colnames(.data), # TODO: is this order correct? .old_data is being used for the new variable..
new = colnames(.old_data) %>% purrr::discard(~ .x == "SurveyID")
)
} else {
NA
rename(., any_of(names_lookup)) %>%
bind_cols(key_var, .)
}

return(.data)
return(out)
}

0 comments on commit 19e94e7

Please sign in to comment.