diff --git a/R/sim_pw_surv.R b/R/sim_pw_surv.R index a9c604fc..6d4b3a69 100644 --- a/R/sim_pw_surv.R +++ b/R/sim_pw_surv.R @@ -142,6 +142,13 @@ sim_pw_surv <- function( duration = rep(100, 2), rate = rep(.001, 2) )) { + # Enforce consistent treatment names + treatments <- unique(c(block, fail_rate$treatment, dropout_rate$treatment)) + stopifnot( + treatments %in% block, + treatments %in% fail_rate$treatment, + treatments %in% dropout_rate$treatment + ) # Start table by generating stratum and enrollment times x <- data.table(stratum = sample( x = stratum$stratum, diff --git a/tests/testthat/test-double_programming_simPWSurv.R b/tests/testthat/test-double_programming_simPWSurv.R index ef7fa4d6..d8460b62 100644 --- a/tests/testthat/test-double_programming_simPWSurv.R +++ b/tests/testthat/test-double_programming_simPWSurv.R @@ -127,3 +127,30 @@ zevent <- dplyr::bind_rows(rate00, rate01, rate10, rate11) testthat::test_that("The actual number of events changes by changing total sample size", { expect_false(unique(xevent$event == zevent$event)) }) + +testthat::test_that("sim_pw_surv() fails early with mismatched treatment names", { + block <- c(rep("x", 2), rep("y", 2)) + fail_rate <- data.frame( + stratum = rep("All", 4), + period = rep(1:2, 2), + treatment = c(rep("x", 2), rep("y", 2)), + duration = rep(c(3, 1), 2), + rate = log(2) / c(9, 9, 9, 18) + ) + dropout_rate <- data.frame( + stratum = rep("All", 2), + period = rep(1, 2), + treatment = c("x", "y"), + duration = rep(100, 2), + rate = rep(0.001, 2) + ) + + expect_error(sim_pw_surv(block = block)) + expect_error(sim_pw_surv(fail_rate = fail_rate)) + expect_error(sim_pw_surv(dropout_rate = dropout_rate)) + # works as long as treatment names are consistent + expect_silent( + xy <- sim_pw_surv(block = block, fail_rate = fail_rate, dropout_rate = dropout_rate) + ) + expect_identical(sort(unique(xy$treatment)), c("x", "y")) +})