Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modify estimate_weight function #129

Merged
merged 7 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 23 additions & 7 deletions R/matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#' procedure will not be triggered, and hence the element `"boot"` of output list object will be NULL.
#' @param set_seed_boot a scalar, the random seed for conducting the bootstrapping, only relevant if
#' \code{n_boot_iteration} is not NULL. By default, use seed 1234
#' @param boot_strata a character vector of column names in \code{data} that defines the strata for bootstrapping.
#' This ensures that samples are drawn proportionally from each defined stratum. If \code{NULL},
#' no stratification during bootstrapping process. By default, it is "ARM"
#' @param ... Additional `control` parameters passed to [stats::optim].
#'
#' @return a list with the following 4 elements,
Expand All @@ -29,13 +32,15 @@
#' modifiers}
#' \item{ess}{effective sample size, square of sum divided by sum of squares}
#' \item{opt}{R object returned by \code{base::optim()}, for assess convergence and other details}
#' \item{boot_strata}{'strata' from a boot::boot object}
#' \item{boot_seed}{column names in \code{data} of the stratification factors}
#' \item{boot}{a n by 2 by k array or NA, where n equals to number of rows in \code{data}, and k equals
#' \code{n_boot_iteration}. The 2 columns in the second dimension include a column of numeric indexes of the rows
#' in \code{data} that are selected at a bootstrapping iteration and a column of weights. \code{boot} is NA when
#' argument \code{n_boot_iteration} is set as NULL
#' }
#' }
#'
#' @importFrom boot boot
#' @examples
#' data(centered_ipd_sat)
#' centered_colnames <- grep("_CENTERED", colnames(centered_ipd_sat), value = TRUE)
Expand All @@ -55,6 +60,7 @@ estimate_weights <- function(data,
method = "BFGS",
n_boot_iteration = NULL,
set_seed_boot = 1234,
boot_strata = "ARM",
...) {
# pre check
ch1 <- is.data.frame(data)
Expand Down Expand Up @@ -87,6 +93,13 @@ estimate_weights <- function(data,
))
}

if (!is.null(boot_strata)) {
ch4 <- boot_strata %in% names(data)
if (!all(ch4)) {
stop("Some variables in boot_strata are not in data: ", toString(boot_strata[!ch4]))
}
}

# prepare data for optimization
if (is.null(centered_colnames)) centered_colnames <- seq_len(ncol(data))
EM <- data[, centered_colnames, drop = FALSE]
Expand All @@ -104,25 +117,28 @@ estimate_weights <- function(data,
# bootstrapping
outboot <- if (is.null(n_boot_iteration)) {
boot_seed <- NULL
boot_strata <- NULL
boot_strata_out <- NULL
NULL
} else {
# Make sure to leave '.Random.seed' as-is on exit
old_seed <- globalenv()$.Random.seed
on.exit(suspendInterrupts(set_random_seed(old_seed)))
set.seed(set_seed_boot)

rowid_in_data <- which(!ind)
arms <- factor(data$ARM[rowid_in_data])
if (!is.null(boot_strata)) {
use_strata <- interaction(data[!ind, boot_strata])
} else {
use_strata <- rep(1, nrow(EM))
}
boot_statistic <- function(d, w) optimise_weights(d[w, ], par = alpha, method = method, ...)$wt[, 1]
boot_out <- boot(EM, statistic = boot_statistic, R = n_boot_iteration, strata = arms)
boot_out <- boot::boot(EM, statistic = boot_statistic, R = n_boot_iteration, strata = use_strata)

boot_array <- array(dim = list(nrow(EM), 2, n_boot_iteration))
dimnames(boot_array) <- list(sampled_patient = NULL, c("rowid", "weight"), bootstrap_iteration = NULL)
boot_array[, 1, ] <- t(boot.array(boot_out, TRUE))
boot_array[, 2, ] <- t(boot_out$t)
boot_seed <- boot_out$seed
boot_strata <- boot_out$strata
boot_strata_out <- boot_out$strata
boot_array
}

Expand All @@ -144,7 +160,7 @@ estimate_weights <- function(data,
opt = opt1$opt,
boot = outboot,
boot_seed = boot_seed,
boot_strata = boot_strata,
boot_strata = boot_strata_out,
rows_with_missing = rows_with_missing
)

Expand Down
7 changes: 7 additions & 0 deletions man/estimate_weights.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 34 additions & 0 deletions tests/testthat/test-matching.R
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,40 @@ test_that("estimate_weights works as expected with bootstrapping", {
})


test_that("estimate_weights works as expected with alternative bootstrap strata", {
load(system.file("extdata", "ipd.rda", package = "maicplus", mustWork = TRUE))
load(system.file("extdata", "agd.rda", package = "maicplus", mustWork = TRUE))
ipd_centered <- center_ipd(ipd = ipd, agd = agd)
centered_colnames <- paste0(c("AGE", "AGE_SQUARED", "ECOG0", "SMOKE", "N_PR_THER_MEDIAN"), "_CENTERED")
expect_output(
result <- estimate_weights(
data = ipd_centered,
centered_colnames = centered_colnames,
n_boot_iteration = 3,
set_seed = 999,
trace = 2,
boot_strata = c("ARM", "SEX")
),
"converged"
)

expect_s3_class(result, "maicplus_estimate_weights")
expect_equal(sum(result$data$weights), 206.83843133)

expect_error(
result <- estimate_weights(
data = ipd_centered,
centered_colnames = centered_colnames,
n_boot_iteration = 3,
set_seed = 999,
trace = 2,
boot_strata = "FISH"
),
"boot_strata are not in data"
)
})


test_that("estimate_weights works when the input is a tibble", {
skip_if_not_installed("tibble")
load(system.file("extdata", "ipd.rda", package = "maicplus", mustWork = TRUE))
Expand Down
Loading